1 #ifndef CAFFE_COMMON_HPP_
2 #define CAFFE_COMMON_HPP_
4 #include <boost/shared_ptr.hpp>
5 #include <gflags/gflags.h>
6 #include <glog/logging.h>
19 #include "caffe/util/device_alternate.hpp"
// STRINGIFY(x) turns its argument into a string literal exactly as written.
// AS_STRING(x) lets x be macro-expanded first and then stringifies the result.
#define STRINGIFY(x) #x
#define AS_STRING(x) STRINGIFY(x)
30 #ifndef GFLAGS_GFLAGS_H_
31 namespace gflags = google;
32 #endif // GFLAGS_GFLAGS_H_
// Declares, without defining, the copy constructor and the copy assignment
// operator of `classname`. The expansion has no trailing semicolon (the call
// site supplies it) and no access specifier, so the declarations take the
// access of the section in which the macro is expanded.
#define DISABLE_COPY_AND_ASSIGN(classname) \
  classname(const classname&);             \
  classname& operator=(const classname&)
// Explicitly instantiates the class template `classname` for float and double.
// Also defines a namespace-scope char named gInstantiationGuard<classname>.
// NOTE(review): the guard variable presumably makes a second instantiation of
// the same class fail at link time - confirm.
// No trailing semicolon: the call site supplies it.
#define INSTANTIATE_CLASS(classname)   \
  char gInstantiationGuard##classname; \
  template class classname<float>;     \
  template class classname<double>
// Explicitly instantiates classname<Dtype>::Forward_gpu(bottom, top) for
// Dtype = float and Dtype = double.
// The expansion deliberately has NO trailing semicolon; the call site supplies
// it, the same convention INSTANTIATE_CLASS and INSTANTIATE_LAYER_GPU_BACKWARD
// follow. A semicolon inside the macro made every `MACRO(x);` use (including
// the one in INSTANTIATE_LAYER_GPU_FUNCS) expand to `;;`, i.e. a stray empty
// declaration.
#define INSTANTIATE_LAYER_GPU_FORWARD(classname) \
  template void classname<float>::Forward_gpu( \
      const std::vector<Blob<float>*>& bottom, \
      const std::vector<Blob<float>*>& top); \
  template void classname<double>::Forward_gpu( \
      const std::vector<Blob<double>*>& bottom, \
      const std::vector<Blob<double>*>& top)
// Explicitly instantiates
//   classname<Dtype>::Backward_gpu(top, propagate_down, bottom)
// for Dtype = float and Dtype = double.
// No trailing semicolon: the call site supplies it.
#define INSTANTIATE_LAYER_GPU_BACKWARD(classname)  \
  template void classname<float>::Backward_gpu(    \
      const std::vector<Blob<float>*>& top,        \
      const std::vector<bool>& propagate_down,     \
      const std::vector<Blob<float>*>& bottom);    \
  template void classname<double>::Backward_gpu(   \
      const std::vector<Blob<double>*>& top,       \
      const std::vector<bool>& propagate_down,     \
      const std::vector<Blob<double>*>& bottom)
// Convenience wrapper: expands to INSTANTIATE_LAYER_GPU_FORWARD(classname)
// followed by INSTANTIATE_LAYER_GPU_BACKWARD(classname).
// No trailing semicolon: the call site supplies it.
#define INSTANTIATE_LAYER_GPU_FUNCS(classname) \
  INSTANTIATE_LAYER_GPU_FORWARD(classname);    \
  INSTANTIATE_LAYER_GPU_BACKWARD(classname)
// Statement-like marker for unimplemented code paths: streams the message
// "Not Implemented Yet" to glog's FATAL log. The call site adds the semicolon.
#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
// Forward declaration of cv::Mat so it can be named by pointer or reference
// here without a definition being visible.
namespace cv {
class Mat;
}  // namespace cv
79 using boost::shared_ptr;
89 using std::ostringstream;
93 using std::stringstream;
// Declaration only; the definition lives elsewhere.
// Takes pointers to argc/argv-style values (int* and char***).
// NOTE(review): presumably forwards them to gflags/glog initialisation, which
// may rewrite them - confirm against the definition.
void GlobalInit(int* pargc, char*** pargv);
// Execution mode selector, as returned by mode() and accepted by set_mode().
enum Brew {
  CPU,  // == 0
  GPU   // == 1
};
118 explicit RNG(
unsigned int seed);
120 RNG& operator=(
const RNG&);
124 shared_ptr<Generator> generator_;
128 inline static RNG& rng_stream() {
129 if (!Get().random_generator_) {
130 Get().random_generator_.reset(
new RNG());
132 return *(Get().random_generator_);
135 inline static cublasHandle_t cublas_handle() {
return Get().cublas_handle_; }
136 inline static curandGenerator_t curand_generator() {
137 return Get().curand_generator_;
142 inline static Brew mode() {
return Get().mode_; }
148 inline static void set_mode(Brew mode) { Get().mode_ = mode; }
// Declaration only; defined elsewhere.
// NOTE(review): presumably reseeds the random generators with `seed` -
// confirm against the definition.
static void set_random_seed(const unsigned int seed);
// Declaration only; defined elsewhere.
// NOTE(review): presumably selects the device identified by `device_id` -
// confirm against the definition.
static void SetDevice(const int device_id);
// Declaration only; defined elsewhere. Takes no arguments, returns nothing.
// NOTE(review): presumably reports properties of the current device - confirm.
static void DeviceQuery();
// Declaration only; defined elsewhere. Returns a bool for `device_id`.
// NOTE(review): presumably true means the device is usable - confirm.
static bool CheckDevice(const int device_id);
// Declaration only; defined elsewhere. `start_id` defaults to 0.
// NOTE(review): presumably searches device ids from `start_id` upward and
// returns the first match - confirm, including the not-found return value.
static int FindDevice(const int start_id = 0);
162 inline static int solver_count() {
return Get().solver_count_; }
163 inline static void set_solver_count(
int val) { Get().solver_count_ = val; }
164 inline static int solver_rank() {
return Get().solver_rank_; }
165 inline static void set_solver_rank(
int val) { Get().solver_rank_ = val; }
166 inline static bool multiprocess() {
return Get().multiprocess_; }
167 inline static void set_multiprocess(
bool val) { Get().multiprocess_ = val; }
168 inline static bool root_solver() {
return Get().solver_rank_ == 0; }
172 cublasHandle_t cublas_handle_;
173 curandGenerator_t curand_generator_;
175 shared_ptr<RNG> random_generator_;
188 DISABLE_COPY_AND_ASSIGN(Caffe);
193 #endif // CAFFE_COMMON_HPP_