Caffe
All Classes Namespaces Functions Variables Typedefs
cudnn.hpp
1 #ifndef CAFFE_UTIL_CUDNN_H_
2 #define CAFFE_UTIL_CUDNN_H_
3 #ifdef USE_CUDNN
4 
5 #include <cudnn.h>
6 
7 #include "caffe/common.hpp"
8 #include "caffe/proto/caffe.pb.h"
9 
// Evaluates to true when CUDNN_VERSION is at least the version encoded as
// major * 1000 + minor * 100 + patch. Each argument is parenthesized so that
// expression arguments (e.g. CUDNN_VERSION_MIN(5 + 1, 0, 0)) are weighted as a
// whole rather than having only their last operand multiplied.
// NOTE(review): relies on CUDNN_VERSION using this same encoding; confirm
// against the cudnn.h actually being compiled against.
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))
12 
// Evaluates `condition` exactly once. If the resulting cudnnStatus_t is not
// CUDNN_STATUS_SUCCESS, the CHECK_EQ fails and its message includes the name
// returned by cudnnGetErrorString.
// The temporary deliberately uses an unusual name. With a plain `status`, an
// argument that mentions a caller variable called `status` (e.g.
// CUDNN_CHECK(status)) would expand to `cudnnStatus_t status = status;`,
// which initializes the local from itself.
#define CUDNN_CHECK(condition) \
  do { \
    const cudnnStatus_t caffe_cudnn_check_status_ = (condition); \
    CHECK_EQ(caffe_cudnn_check_status_, CUDNN_STATUS_SUCCESS) << " " \
        << cudnnGetErrorString(caffe_cudnn_check_status_); \
  } while (0)
19 
20 inline const char* cudnnGetErrorString(cudnnStatus_t status) {
21  switch (status) {
22  case CUDNN_STATUS_SUCCESS:
23  return "CUDNN_STATUS_SUCCESS";
24  case CUDNN_STATUS_NOT_INITIALIZED:
25  return "CUDNN_STATUS_NOT_INITIALIZED";
26  case CUDNN_STATUS_ALLOC_FAILED:
27  return "CUDNN_STATUS_ALLOC_FAILED";
28  case CUDNN_STATUS_BAD_PARAM:
29  return "CUDNN_STATUS_BAD_PARAM";
30  case CUDNN_STATUS_INTERNAL_ERROR:
31  return "CUDNN_STATUS_INTERNAL_ERROR";
32  case CUDNN_STATUS_INVALID_VALUE:
33  return "CUDNN_STATUS_INVALID_VALUE";
34  case CUDNN_STATUS_ARCH_MISMATCH:
35  return "CUDNN_STATUS_ARCH_MISMATCH";
36  case CUDNN_STATUS_MAPPING_ERROR:
37  return "CUDNN_STATUS_MAPPING_ERROR";
38  case CUDNN_STATUS_EXECUTION_FAILED:
39  return "CUDNN_STATUS_EXECUTION_FAILED";
40  case CUDNN_STATUS_NOT_SUPPORTED:
41  return "CUDNN_STATUS_NOT_SUPPORTED";
42  case CUDNN_STATUS_LICENSE_ERROR:
43  return "CUDNN_STATUS_LICENSE_ERROR";
44 #if CUDNN_VERSION_MIN(6, 0, 0)
45  case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
46  return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
47 #endif
48 #if CUDNN_VERSION_MIN(7, 0, 0)
49  case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
50  return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
51  case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
52  return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
53 #endif
54  }
55  return "Unknown cudnn status";
56 }
57 
58 namespace caffe {
59 
60 namespace cudnn {
61 
62 template <typename Dtype> class dataType;
63 template<> class dataType<float> {
64  public:
65  static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
66  static float oneval, zeroval;
67  static const void *one, *zero;
68 };
69 template<> class dataType<double> {
70  public:
71  static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
72  static double oneval, zeroval;
73  static const void *one, *zero;
74 };
75 
76 template <typename Dtype>
77 inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
78  CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
79 }
80 
81 template <typename Dtype>
82 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
83  int n, int c, int h, int w,
84  int stride_n, int stride_c, int stride_h, int stride_w) {
85  CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
86  n, c, h, w, stride_n, stride_c, stride_h, stride_w));
87 }
88 
89 template <typename Dtype>
90 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
91  int n, int c, int h, int w) {
92  const int stride_w = 1;
93  const int stride_h = w * stride_w;
94  const int stride_c = h * stride_h;
95  const int stride_n = c * stride_c;
96  setTensor4dDesc<Dtype>(desc, n, c, h, w,
97  stride_n, stride_c, stride_h, stride_w);
98 }
99 
100 template <typename Dtype>
101 inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
102  int n, int c, int h, int w) {
103  CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
104 #if CUDNN_VERSION_MIN(5, 0, 0)
105  CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
106  CUDNN_TENSOR_NCHW, n, c, h, w));
107 #else
108  CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
109  CUDNN_TENSOR_NCHW, n, c, h, w));
110 #endif
111 }
112 
113 template <typename Dtype>
114 inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {
115  CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
116 }
117 
118 template <typename Dtype>
119 inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
120  cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
121  int pad_h, int pad_w, int stride_h, int stride_w) {
122 #if CUDNN_VERSION_MIN(6, 0, 0)
123  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
124  pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION,
125  dataType<Dtype>::type));
126 #else
127  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
128  pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
129 #endif
130 }
131 
132 template <typename Dtype>
133 inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
134  PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
135  int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) {
136  switch (poolmethod) {
137  case PoolingParameter_PoolMethod_MAX:
138  *mode = CUDNN_POOLING_MAX;
139  break;
140  case PoolingParameter_PoolMethod_AVE:
141  *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
142  break;
143  default:
144  LOG(FATAL) << "Unknown pooling method.";
145  }
146  CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
147 #if CUDNN_VERSION_MIN(5, 0, 0)
148  CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
149  CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
150 #else
151  CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
152  CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
153 #endif
154 }
155 
156 template <typename Dtype>
157 inline void createActivationDescriptor(cudnnActivationDescriptor_t* activ_desc,
158  cudnnActivationMode_t mode) {
159  CUDNN_CHECK(cudnnCreateActivationDescriptor(activ_desc));
160  CUDNN_CHECK(cudnnSetActivationDescriptor(*activ_desc, mode,
161  CUDNN_PROPAGATE_NAN, Dtype(0)));
162 }
163 
164 } // namespace cudnn
165 
166 } // namespace caffe
167 
168 #endif // USE_CUDNN
169 #endif // CAFFE_UTIL_CUDNN_H_
caffe
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14