5 #ifndef CAFFE_FILLER_HPP
6 #define CAFFE_FILLER_HPP
10 #include "caffe/blob.hpp"
11 #include "caffe/proto/caffe.pb.h"
12 #include "caffe/syncedmem.hpp"
13 #include "caffe/util/math_functions.hpp"
18 template <
typename Dtype>
21 explicit Filler(
const FillerParameter& param) : filler_param_(param) {}
25 FillerParameter filler_param_;
30 template <
typename Dtype>
36 Dtype* data = blob->mutable_cpu_data();
37 const int count = blob->count();
38 const Dtype value = this->filler_param_.value();
40 for (
int i = 0; i < count; ++i) {
43 CHECK_EQ(this->filler_param_.sparse(), -1)
44 <<
"Sparsity not supported by this Filler.";
49 template <
typename Dtype>
56 caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
57 Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
58 CHECK_EQ(this->filler_param_.sparse(), -1)
59 <<
"Sparsity not supported by this Filler.";
64 template <
typename Dtype>
70 Dtype* data = blob->mutable_cpu_data();
72 caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
73 Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
74 int sparse = this->filler_param_.sparse();
81 CHECK_GE(blob->num_axes(), 1);
82 const int num_outputs = blob->shape(0);
83 Dtype non_zero_probability = Dtype(sparse) / Dtype(num_outputs);
84 rand_vec_.reset(
new SyncedMemory(blob->count() *
sizeof(
int)));
85 int* mask =
reinterpret_cast<int*
>(rand_vec_->mutable_cpu_data());
86 caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
87 for (
int i = 0; i < blob->count(); ++i) {
94 shared_ptr<SyncedMemory> rand_vec_;
100 template <
typename Dtype>
106 Dtype* data = blob->mutable_cpu_data();
107 DCHECK(blob->count());
108 caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
111 int dim = blob->count() / blob->shape(0);
113 for (
int i = 0; i < blob->shape(0); ++i) {
115 for (
int j = 0; j < dim; ++j) {
116 sum += data[i * dim + j];
118 for (
int j = 0; j < dim; ++j) {
119 data[i * dim + j] /= sum;
122 CHECK_EQ(this->filler_param_.sparse(), -1)
123 <<
"Sparsity not supported by this Filler.";
143 template <
typename Dtype>
149 CHECK(blob->count());
150 int fan_in = blob->count() / blob->shape(0);
152 int fan_out = blob->num_axes() > 1 ?
153 blob->count() / blob->shape(1) :
156 if (this->filler_param_.variance_norm() ==
157 FillerParameter_VarianceNorm_AVERAGE) {
158 n = (fan_in + fan_out) / Dtype(2);
159 }
else if (this->filler_param_.variance_norm() ==
160 FillerParameter_VarianceNorm_FAN_OUT) {
163 Dtype scale = sqrt(Dtype(3) / n);
164 caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
165 blob->mutable_cpu_data());
166 CHECK_EQ(this->filler_param_.sparse(), -1)
167 <<
"Sparsity not supported by this Filler.";
188 template <
typename Dtype>
191 explicit MSRAFiller(
const FillerParameter& param)
194 CHECK(blob->count());
195 int fan_in = blob->count() / blob->shape(0);
197 int fan_out = blob->num_axes() > 1 ?
198 blob->count() / blob->shape(1) :
201 if (this->filler_param_.variance_norm() ==
202 FillerParameter_VarianceNorm_AVERAGE) {
203 n = (fan_in + fan_out) / Dtype(2);
204 }
else if (this->filler_param_.variance_norm() ==
205 FillerParameter_VarianceNorm_FAN_OUT) {
208 Dtype std = sqrt(Dtype(2) / n);
209 caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
210 blob->mutable_cpu_data());
211 CHECK_EQ(this->filler_param_.sparse(), -1)
212 <<
"Sparsity not supported by this Filler.";
249 template <
typename Dtype>
255 CHECK_EQ(blob->num_axes(), 4) <<
"Blob must be 4 dim.";
256 CHECK_EQ(blob->
width(), blob->
height()) <<
"Filter must be square";
257 Dtype* data = blob->mutable_cpu_data();
258 int f = ceil(blob->
width() / 2.);
259 Dtype c = (blob->
width() - 1) / (2. * f);
260 for (
int i = 0; i < blob->count(); ++i) {
261 Dtype x = i % blob->
width();
263 data[i] = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
265 CHECK_EQ(this->filler_param_.sparse(), -1)
266 <<
"Sparsity not supported by this Filler.";
276 template <
typename Dtype>
278 const std::string& type = param.type();
279 if (type ==
"constant") {
281 }
else if (type ==
"gaussian") {
283 }
else if (type ==
"positive_unitball") {
285 }
else if (type ==
"uniform") {
287 }
else if (type ==
"xavier") {
289 }
else if (type ==
"msra") {
291 }
else if (type ==
"bilinear") {
294 CHECK(
false) <<
"Unknown filler name: " << param.type();
301 #endif // CAFFE_FILLER_HPP_