Caffe
All Classes Namespaces Functions Variables Typedefs
sgd_solvers.hpp
1 #ifndef CAFFE_SGD_SOLVERS_HPP_
2 #define CAFFE_SGD_SOLVERS_HPP_
3 
4 #include <string>
5 #include <vector>
6 
7 #include "caffe/solver.hpp"
8 
9 namespace caffe {
10 
15 template <typename Dtype>
16 class SGDSolver : public Solver<Dtype> {
17  public:
18  explicit SGDSolver(const SolverParameter& param)
19  : Solver<Dtype>(param) { PreSolve(); }
20  explicit SGDSolver(const string& param_file)
21  : Solver<Dtype>(param_file) { PreSolve(); }
22  virtual inline const char* type() const { return "SGD"; }
23 
24  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
25 
26  virtual void ApplyUpdate();
27  Dtype GetLearningRate();
28 
29  protected:
30  void PreSolve();
31  virtual void Normalize(int param_id);
32  virtual void Regularize(int param_id);
33  virtual void ComputeUpdateValue(int param_id, Dtype rate);
34  virtual void ClipGradients();
35  virtual void SnapshotSolverState(const string& model_filename);
36  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
37  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
38  virtual void RestoreSolverStateFromHDF5(const string& state_file);
39  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
40  // history maintains the historical momentum data.
41  // update maintains update related data and is not needed in snapshots.
42  // temp maintains other information that might be needed in computation
43  // of gradients/updates and is not needed in snapshots
44  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
45 
46  DISABLE_COPY_AND_ASSIGN(SGDSolver);
47 };
48 
49 template <typename Dtype>
50 class NesterovSolver : public SGDSolver<Dtype> {
51  public:
52  explicit NesterovSolver(const SolverParameter& param)
53  : SGDSolver<Dtype>(param) {}
54  explicit NesterovSolver(const string& param_file)
55  : SGDSolver<Dtype>(param_file) {}
56  virtual inline const char* type() const { return "Nesterov"; }
57 
58  protected:
59  virtual void ComputeUpdateValue(int param_id, Dtype rate);
60 
61  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
62 };
63 
64 template <typename Dtype>
65 class AdaGradSolver : public SGDSolver<Dtype> {
66  public:
67  explicit AdaGradSolver(const SolverParameter& param)
68  : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
69  explicit AdaGradSolver(const string& param_file)
70  : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
71  virtual inline const char* type() const { return "AdaGrad"; }
72 
73  protected:
74  virtual void ComputeUpdateValue(int param_id, Dtype rate);
75  void constructor_sanity_check() {
76  CHECK_EQ(0, this->param_.momentum())
77  << "Momentum cannot be used with AdaGrad.";
78  }
79 
80  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
81 };
82 
83 
84 template <typename Dtype>
85 class RMSPropSolver : public SGDSolver<Dtype> {
86  public:
87  explicit RMSPropSolver(const SolverParameter& param)
88  : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
89  explicit RMSPropSolver(const string& param_file)
90  : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
91  virtual inline const char* type() const { return "RMSProp"; }
92 
93  protected:
94  virtual void ComputeUpdateValue(int param_id, Dtype rate);
95  void constructor_sanity_check() {
96  CHECK_EQ(0, this->param_.momentum())
97  << "Momentum cannot be used with RMSProp.";
98  CHECK_GE(this->param_.rms_decay(), 0)
99  << "rms_decay should lie between 0 and 1.";
100  CHECK_LT(this->param_.rms_decay(), 1)
101  << "rms_decay should lie between 0 and 1.";
102  }
103 
104  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
105 };
106 
107 template <typename Dtype>
108 class AdaDeltaSolver : public SGDSolver<Dtype> {
109  public:
110  explicit AdaDeltaSolver(const SolverParameter& param)
111  : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
112  explicit AdaDeltaSolver(const string& param_file)
113  : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
114  virtual inline const char* type() const { return "AdaDelta"; }
115 
116  protected:
117  void AdaDeltaPreSolve();
118  virtual void ComputeUpdateValue(int param_id, Dtype rate);
119 
120  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
121 };
122 
131 template <typename Dtype>
132 class AdamSolver : public SGDSolver<Dtype> {
133  public:
134  explicit AdamSolver(const SolverParameter& param)
135  : SGDSolver<Dtype>(param) { AdamPreSolve();}
136  explicit AdamSolver(const string& param_file)
137  : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
138  virtual inline const char* type() const { return "Adam"; }
139 
140  protected:
141  void AdamPreSolve();
142  virtual void ComputeUpdateValue(int param_id, Dtype rate);
143 
144  DISABLE_COPY_AND_ASSIGN(AdamSolver);
145 };
146 
147 } // namespace caffe
148 
149 #endif // CAFFE_SGD_SOLVERS_HPP_
caffe::AdaGradSolver
Definition: sgd_solvers.hpp:65
caffe::NesterovSolver
Definition: sgd_solvers.hpp:50
caffe::NesterovSolver::type
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:56
caffe::RMSPropSolver
Definition: sgd_solvers.hpp:85
caffe::AdaGradSolver::type
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:71
caffe::SGDSolver
Optimizes the parameters of a Net using stochastic gradient descent (SGD) with momentum.
Definition: sgd_solvers.hpp:16
caffe::AdamSolver
AdamSolver, an algorithm for first-order gradient-based optimization of stochastic objective functions.
Definition: sgd_solvers.hpp:132
caffe::Solver
An interface for classes that perform optimization on Nets.
Definition: solver.hpp:42
caffe::SGDSolver::type
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:22
caffe::AdamSolver::type
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:138
caffe::AdaDeltaSolver::type
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:114
caffe::AdaDeltaSolver
Definition: sgd_solvers.hpp:108
caffe
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14
caffe::RMSPropSolver::type
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:91