1 #ifndef CAFFE_CUDNN_LRN_LAYER_HPP_
2 #define CAFFE_CUDNN_LRN_LAYER_HPP_
6 #include "caffe/blob.hpp"
7 #include "caffe/layer.hpp"
8 #include "caffe/proto/caffe.pb.h"
10 #include "caffe/layers/lrn_layer.hpp"
// CuDNNLRNLayer: a subclass of LRNLayer<Dtype>. The members visible in this
// view are cuDNN objects: a cudnnHandle_t, a cudnnLRNDescriptor_t and two
// cudnnTensorDescriptor_t. This is presumably a cuDNN-backed implementation of
// the LRN layer. Only GPU forward/backward passes are declared in the lines
// shown.
//
// NOTE(review): each code line in this class begins with a stray integer
// ("15", "16", ...). These look like source line numbers fused in by a bad
// paste/extract. They are not valid C++ and should be stripped once the
// original header is recovered.
// NOTE(review): no `public:` / `protected:` access specifier appears in the
// visible lines. As shown, every member of this `class`, including the
// constructor, would therefore be private by default. The fused numbering
// skips 17 and 25-26, so the specifiers were probably dropped. Confirm
// against the upstream header.
15 template <
typename Dtype>
16 class CuDNNLRNLayer :
public LRNLayer<Dtype> {
// Constructs the layer from its LayerParameter.
// - `param` is forwarded unchanged to the LRNLayer<Dtype> base constructor.
// - handles_setup_ starts out false.
// - The body is empty.
// `explicit` prevents implicit conversion from a LayerParameter.
// handles_setup_ presumably records whether the cuDNN handle/descriptors have
// been created yet, e.g. so the destructor only releases what exists.
// TODO confirm against the LayerSetUp and destructor definitions, which are
// not shown here.
// NOTE(review): handles_setup_ is not declared in any visible line (the fused
// numbering skips 31-32 in the member list). Confirm that the declaration
// exists in the real header.
18 explicit CuDNNLRNLayer(
const LayerParameter& param)
19 : LRNLayer<Dtype>(param), handles_setup_(false) {}
// Setup and shape hooks. Their definitions are not in this header.
// Both take `bottom` and `top` as const references to vectors of Blob<Dtype>
// pointers. These are presumably the layer's input and output blobs, per the
// usual Caffe naming.
// Presumably LayerSetUp creates the cuDNN handle and descriptors, and Reshape
// refreshes the tensor descriptors for the current blob shapes. TODO confirm
// against the implementation file.
20 virtual void LayerSetUp(
const vector<Blob<Dtype>*>& bottom,
21 const vector<Blob<Dtype>*>& top);
22 virtual void Reshape(
const vector<Blob<Dtype>*>& bottom,
23 const vector<Blob<Dtype>*>& top);
// Virtual destructor, defined out of line (not visible). It presumably
// releases the cuDNN objects declared below when handles_setup_ is set.
// Confirm against the implementation.
24 virtual ~CuDNNLRNLayer();
// GPU forward and backward passes. Their definitions are not in this header.
// Only the *_gpu variants are declared in the visible lines. CPU execution
// presumably uses whatever LRNLayer<Dtype> provides. Confirm.
// Forward_gpu: takes `bottom` and `top` blob vectors. It presumably reads
// bottom and writes top.
27 virtual void Forward_gpu(
const vector<Blob<Dtype>*>& bottom,
28 const vector<Blob<Dtype>*>& top);
// Backward_gpu: takes `top`, a vector<bool> `propagate_down`, and `bottom`.
// It presumably computes gradients for the bottom blobs whose propagate_down
// entry is true. Confirm against the implementation.
29 virtual void Backward_gpu(
const vector<Blob<Dtype>*>& top,
30 const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom);
// cuDNN state:
// - handle_: the cuDNN library handle.
// - norm_desc_: the LRN descriptor.
// - bottom_desc_, top_desc_: tensor descriptors, named for the bottom and top
//   blobs.
33 cudnnHandle_t handle_;
34 cudnnLRNDescriptor_t norm_desc_;
35 cudnnTensorDescriptor_t bottom_desc_, top_desc_;
// alpha_, beta_ and k_ are stored as Dtype. They are presumably the LRN
// scale, exponent and additive constant taken from the layer's LRN
// parameters. TODO confirm where they are assigned.
// NOTE(review): the fused numbering skips 31-32 and 36-37 in this member
// list, so some declarations appear to have been dropped. One example is the
// handles_setup_ that the constructor initialises. The class's closing `};`
// is also not visible.
38 Dtype alpha_, beta_, k_;
44 #endif // CAFFE_CUDNN_LRN_LAYER_HPP_