Caffe
All Classes Namespaces Functions Variables Typedefs
filler.hpp
1 // Fillers are (mostly random-number) generators that fill a blob using the
2 // specified algorithm. The expectation is that they are only going to be used
3 // during initialization time and will not involve any GPUs.
4 
5 #ifndef CAFFE_FILLER_HPP
6 #define CAFFE_FILLER_HPP
7 
8 #include <string>
9 
10 #include "caffe/blob.hpp"
11 #include "caffe/proto/caffe.pb.h"
12 #include "caffe/syncedmem.hpp"
13 #include "caffe/util/math_functions.hpp"
14 
15 namespace caffe {
16 
18 template <typename Dtype>
19 class Filler {
20  public:
21  explicit Filler(const FillerParameter& param) : filler_param_(param) {}
22  virtual ~Filler() {}
23  virtual void Fill(Blob<Dtype>* blob) = 0;
24  protected:
25  FillerParameter filler_param_;
26 }; // class Filler
27 
28 
30 template <typename Dtype>
31 class ConstantFiller : public Filler<Dtype> {
32  public:
33  explicit ConstantFiller(const FillerParameter& param)
34  : Filler<Dtype>(param) {}
35  virtual void Fill(Blob<Dtype>* blob) {
36  Dtype* data = blob->mutable_cpu_data();
37  const int count = blob->count();
38  const Dtype value = this->filler_param_.value();
39  CHECK(count);
40  for (int i = 0; i < count; ++i) {
41  data[i] = value;
42  }
43  CHECK_EQ(this->filler_param_.sparse(), -1)
44  << "Sparsity not supported by this Filler.";
45  }
46 };
47 
49 template <typename Dtype>
50 class UniformFiller : public Filler<Dtype> {
51  public:
52  explicit UniformFiller(const FillerParameter& param)
53  : Filler<Dtype>(param) {}
54  virtual void Fill(Blob<Dtype>* blob) {
55  CHECK(blob->count());
56  caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
57  Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
58  CHECK_EQ(this->filler_param_.sparse(), -1)
59  << "Sparsity not supported by this Filler.";
60  }
61 };
62 
64 template <typename Dtype>
65 class GaussianFiller : public Filler<Dtype> {
66  public:
67  explicit GaussianFiller(const FillerParameter& param)
68  : Filler<Dtype>(param) {}
69  virtual void Fill(Blob<Dtype>* blob) {
70  Dtype* data = blob->mutable_cpu_data();
71  CHECK(blob->count());
72  caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
73  Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
74  int sparse = this->filler_param_.sparse();
75  CHECK_GE(sparse, -1);
76  if (sparse >= 0) {
77  // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
78  // These have num == channels == 1; width is number of inputs; height is
79  // number of outputs. The 'sparse' variable specifies the mean number
80  // of non-zero input weights for a given output.
81  CHECK_GE(blob->num_axes(), 1);
82  const int num_outputs = blob->shape(0);
83  Dtype non_zero_probability = Dtype(sparse) / Dtype(num_outputs);
84  rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
85  int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
86  caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
87  for (int i = 0; i < blob->count(); ++i) {
88  data[i] *= mask[i];
89  }
90  }
91  }
92 
93  protected:
94  shared_ptr<SyncedMemory> rand_vec_;
95 };
96 
100 template <typename Dtype>
101 class PositiveUnitballFiller : public Filler<Dtype> {
102  public:
103  explicit PositiveUnitballFiller(const FillerParameter& param)
104  : Filler<Dtype>(param) {}
105  virtual void Fill(Blob<Dtype>* blob) {
106  Dtype* data = blob->mutable_cpu_data();
107  DCHECK(blob->count());
108  caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
109  // We expect the filler to not be called very frequently, so we will
110  // just use a simple implementation
111  int dim = blob->count() / blob->shape(0);
112  CHECK(dim);
113  for (int i = 0; i < blob->shape(0); ++i) {
114  Dtype sum = 0;
115  for (int j = 0; j < dim; ++j) {
116  sum += data[i * dim + j];
117  }
118  for (int j = 0; j < dim; ++j) {
119  data[i * dim + j] /= sum;
120  }
121  }
122  CHECK_EQ(this->filler_param_.sparse(), -1)
123  << "Sparsity not supported by this Filler.";
124  }
125 };
126 
143 template <typename Dtype>
144 class XavierFiller : public Filler<Dtype> {
145  public:
146  explicit XavierFiller(const FillerParameter& param)
147  : Filler<Dtype>(param) {}
148  virtual void Fill(Blob<Dtype>* blob) {
149  CHECK(blob->count());
150  int fan_in = blob->count() / blob->shape(0);
151  // Compatibility with ND blobs
152  int fan_out = blob->num_axes() > 1 ?
153  blob->count() / blob->shape(1) :
154  blob->count();
155  Dtype n = fan_in; // default to fan_in
156  if (this->filler_param_.variance_norm() ==
157  FillerParameter_VarianceNorm_AVERAGE) {
158  n = (fan_in + fan_out) / Dtype(2);
159  } else if (this->filler_param_.variance_norm() ==
160  FillerParameter_VarianceNorm_FAN_OUT) {
161  n = fan_out;
162  }
163  Dtype scale = sqrt(Dtype(3) / n);
164  caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
165  blob->mutable_cpu_data());
166  CHECK_EQ(this->filler_param_.sparse(), -1)
167  << "Sparsity not supported by this Filler.";
168  }
169 };
170 
188 template <typename Dtype>
189 class MSRAFiller : public Filler<Dtype> {
190  public:
191  explicit MSRAFiller(const FillerParameter& param)
192  : Filler<Dtype>(param) {}
193  virtual void Fill(Blob<Dtype>* blob) {
194  CHECK(blob->count());
195  int fan_in = blob->count() / blob->shape(0);
196  // Compatibility with ND blobs
197  int fan_out = blob->num_axes() > 1 ?
198  blob->count() / blob->shape(1) :
199  blob->count();
200  Dtype n = fan_in; // default to fan_in
201  if (this->filler_param_.variance_norm() ==
202  FillerParameter_VarianceNorm_AVERAGE) {
203  n = (fan_in + fan_out) / Dtype(2);
204  } else if (this->filler_param_.variance_norm() ==
205  FillerParameter_VarianceNorm_FAN_OUT) {
206  n = fan_out;
207  }
208  Dtype std = sqrt(Dtype(2) / n);
209  caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
210  blob->mutable_cpu_data());
211  CHECK_EQ(this->filler_param_.sparse(), -1)
212  << "Sparsity not supported by this Filler.";
213  }
214 };
215 
249 template <typename Dtype>
250 class BilinearFiller : public Filler<Dtype> {
251  public:
252  explicit BilinearFiller(const FillerParameter& param)
253  : Filler<Dtype>(param) {}
254  virtual void Fill(Blob<Dtype>* blob) {
255  CHECK_EQ(blob->num_axes(), 4) << "Blob must be 4 dim.";
256  CHECK_EQ(blob->width(), blob->height()) << "Filter must be square";
257  Dtype* data = blob->mutable_cpu_data();
258  int f = ceil(blob->width() / 2.);
259  Dtype c = (blob->width() - 1) / (2. * f);
260  for (int i = 0; i < blob->count(); ++i) {
261  Dtype x = i % blob->width();
262  Dtype y = (i / blob->width()) % blob->height();
263  data[i] = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
264  }
265  CHECK_EQ(this->filler_param_.sparse(), -1)
266  << "Sparsity not supported by this Filler.";
267  }
268 };
269 
276 template <typename Dtype>
277 Filler<Dtype>* GetFiller(const FillerParameter& param) {
278  const std::string& type = param.type();
279  if (type == "constant") {
280  return new ConstantFiller<Dtype>(param);
281  } else if (type == "gaussian") {
282  return new GaussianFiller<Dtype>(param);
283  } else if (type == "positive_unitball") {
284  return new PositiveUnitballFiller<Dtype>(param);
285  } else if (type == "uniform") {
286  return new UniformFiller<Dtype>(param);
287  } else if (type == "xavier") {
288  return new XavierFiller<Dtype>(param);
289  } else if (type == "msra") {
290  return new MSRAFiller<Dtype>(param);
291  } else if (type == "bilinear") {
292  return new BilinearFiller<Dtype>(param);
293  } else {
294  CHECK(false) << "Unknown filler name: " << param.type();
295  }
296  return (Filler<Dtype>*)(NULL);
297 }
298 
299 } // namespace caffe
300 
301 #endif // CAFFE_FILLER_HPP
caffe::SyncedMemory
Manages memory allocation and synchronization between the host (CPU) and device (GPU).
Definition: syncedmem.hpp:57
caffe::XavierFiller
Fills a Blob with values $x \sim U(-a, +a)$ where $a = \sqrt{3/n}$ is set inversely proportional to the number of incoming nodes, outgoing nodes, or their average...
Definition: filler.hpp:144
caffe::MSRAFiller
Fills a Blob with values $x \sim N(0, \sigma^2)$ where $\sigma^2 = 2/n$ is set inversely proportional to the number of incoming nodes, outgoing nodes, or their average...
Definition: filler.hpp:189
caffe::GaussianFiller
Fills a Blob with Gaussian-distributed values $x \sim N(\mu, \sigma^2)$.
Definition: filler.hpp:65
caffe::Blob
A wrapper around SyncedMemory holders serving as the basic computational unit through which Layers,...
Definition: blob.hpp:24
caffe::Blob::width
int width() const
Deprecated legacy shape accessor width: use shape(3) instead.
Definition: blob.hpp:138
caffe::UniformFiller
Fills a Blob with uniformly distributed values $x \sim U(a, b)$.
Definition: filler.hpp:50
caffe::GetFiller
Filler< Dtype > * GetFiller(const FillerParameter &param)
Get a specific filler from the specification given in FillerParameter.
Definition: filler.hpp:277
caffe::Blob::height
int height() const
Deprecated legacy shape accessor height: use shape(2) instead.
Definition: blob.hpp:136
caffe::Filler
Fills a Blob with constant or randomly-generated data.
Definition: filler.hpp:19
caffe
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14
caffe::BilinearFiller
Fills a Blob with coefficients for bilinear interpolation.
Definition: filler.hpp:250
caffe::ConstantFiller
Fills a Blob with constant values $x = c$.
Definition: filler.hpp:31
caffe::PositiveUnitballFiller
Fills a Blob with values $x \in [0, 1]$ such that $\forall i \; \sum_j x_{ij} = 1$.
Definition: filler.hpp:101