1 #ifndef CAFFE_UTIL_CUDNN_H_
2 #define CAFFE_UTIL_CUDNN_H_
7 #include "caffe/common.hpp"
8 #include "caffe/proto/caffe.pb.h"
// Evaluates to true when the cuDNN headers being compiled against are at
// least version major.minor.patch; usable both in `#if` and in expressions.
// The components (CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL) are compared
// individually rather than via the packed CUDNN_VERSION, because that packing
// changed in cuDNN 9 (major * 10000 instead of major * 1000) and a packed
// comparison such as 9*1000+5*100 would be spuriously satisfied by 9.1.
// Every argument is parenthesized so expression arguments expand safely.
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_MAJOR > (major) || \
   (CUDNN_MAJOR == (major) && \
    (CUDNN_MINOR > (minor) || \
     (CUDNN_MINOR == (minor) && CUDNN_PATCHLEVEL >= (patch)))))
// Evaluates a cuDNN call expression exactly once and fails a CHECK_EQ
// (glog: logs FATAL and aborts) carrying the decoded status string when the
// result is not CUDNN_STATUS_SUCCESS.
// - do { } while (0) makes the expansion a single statement, so it is safe in
//   unbraced if/else bodies and requires a trailing semicolon.
// - The local has a deliberately unusual name: with a plain `status`, a call
//   such as CUDNN_CHECK(f(status)) would read the macro's own, not yet
//   initialized, local instead of the caller's variable.
#define CUDNN_CHECK(condition) \
  do { \
    cudnnStatus_t cudnn_check_status_ = (condition); \
    CHECK_EQ(cudnn_check_status_, CUDNN_STATUS_SUCCESS) << " " \
        << cudnnGetErrorString(cudnn_check_status_); \
  } while (0)
20 inline const char* cudnnGetErrorString(cudnnStatus_t status) {
22 case CUDNN_STATUS_SUCCESS:
23 return "CUDNN_STATUS_SUCCESS";
24 case CUDNN_STATUS_NOT_INITIALIZED:
25 return "CUDNN_STATUS_NOT_INITIALIZED";
26 case CUDNN_STATUS_ALLOC_FAILED:
27 return "CUDNN_STATUS_ALLOC_FAILED";
28 case CUDNN_STATUS_BAD_PARAM:
29 return "CUDNN_STATUS_BAD_PARAM";
30 case CUDNN_STATUS_INTERNAL_ERROR:
31 return "CUDNN_STATUS_INTERNAL_ERROR";
32 case CUDNN_STATUS_INVALID_VALUE:
33 return "CUDNN_STATUS_INVALID_VALUE";
34 case CUDNN_STATUS_ARCH_MISMATCH:
35 return "CUDNN_STATUS_ARCH_MISMATCH";
36 case CUDNN_STATUS_MAPPING_ERROR:
37 return "CUDNN_STATUS_MAPPING_ERROR";
38 case CUDNN_STATUS_EXECUTION_FAILED:
39 return "CUDNN_STATUS_EXECUTION_FAILED";
40 case CUDNN_STATUS_NOT_SUPPORTED:
41 return "CUDNN_STATUS_NOT_SUPPORTED";
42 case CUDNN_STATUS_LICENSE_ERROR:
43 return "CUDNN_STATUS_LICENSE_ERROR";
44 #if CUDNN_VERSION_MIN(6, 0, 0)
45 case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
46 return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
48 #if CUDNN_VERSION_MIN(7, 0, 0)
49 case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
50 return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
51 case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
52 return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
55 return "Unknown cudnn status";
62 template <
typename Dtype>
class dataType;
63 template<>
class dataType<float> {
65 static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
66 static float oneval, zeroval;
67 static const void *one, *zero;
69 template<>
class dataType<double> {
71 static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
72 static double oneval, zeroval;
73 static const void *one, *zero;
76 template <
typename Dtype>
77 inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
78 CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
81 template <
typename Dtype>
82 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
83 int n,
int c,
int h,
int w,
84 int stride_n,
int stride_c,
int stride_h,
int stride_w) {
85 CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
86 n, c, h, w, stride_n, stride_c, stride_h, stride_w));
89 template <
typename Dtype>
90 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
91 int n,
int c,
int h,
int w) {
92 const int stride_w = 1;
93 const int stride_h = w * stride_w;
94 const int stride_c = h * stride_h;
95 const int stride_n = c * stride_c;
96 setTensor4dDesc<Dtype>(desc, n, c, h, w,
97 stride_n, stride_c, stride_h, stride_w);
100 template <
typename Dtype>
101 inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
102 int n,
int c,
int h,
int w) {
103 CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
104 #if CUDNN_VERSION_MIN(5, 0, 0)
105 CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
106 CUDNN_TENSOR_NCHW, n, c, h, w));
108 CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
109 CUDNN_TENSOR_NCHW, n, c, h, w));
113 template <
typename Dtype>
114 inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {
115 CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
118 template <
typename Dtype>
119 inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
120 cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
121 int pad_h,
int pad_w,
int stride_h,
int stride_w) {
122 #if CUDNN_VERSION_MIN(6, 0, 0)
123 CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
124 pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION,
125 dataType<Dtype>::type));
127 CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
128 pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
132 template <
typename Dtype>
133 inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
134 PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
135 int h,
int w,
int pad_h,
int pad_w,
int stride_h,
int stride_w) {
136 switch (poolmethod) {
137 case PoolingParameter_PoolMethod_MAX:
138 *mode = CUDNN_POOLING_MAX;
140 case PoolingParameter_PoolMethod_AVE:
141 *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
144 LOG(FATAL) <<
"Unknown pooling method.";
146 CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
147 #if CUDNN_VERSION_MIN(5, 0, 0)
148 CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
149 CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
151 CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
152 CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
156 template <
typename Dtype>
157 inline void createActivationDescriptor(cudnnActivationDescriptor_t* activ_desc,
158 cudnnActivationMode_t mode) {
159 CUDNN_CHECK(cudnnCreateActivationDescriptor(activ_desc));
160 CUDNN_CHECK(cudnnSetActivationDescriptor(*activ_desc, mode,
161 CUDNN_PROPAGATE_NAN, Dtype(0)));
169 #endif // CAFFE_UTIL_CUDNN_H_