1 #ifndef CAFFE_SGD_SOLVERS_HPP_
2 #define CAFFE_SGD_SOLVERS_HPP_
7 #include "caffe/solver.hpp"
15 template <
typename Dtype>
18 explicit SGDSolver(
const SolverParameter& param)
20 explicit SGDSolver(
const string& param_file)
22 virtual inline const char*
type()
const {
return "SGD"; }
24 const vector<shared_ptr<Blob<Dtype> > >& history() {
return history_; }
// ApplyUpdate: virtual, takes no arguments and returns nothing. Only the
// declaration is present in this header.
// NOTE(review): presumably applies one optimization step to the network
// parameters -- confirm against the definition.
26 virtual void ApplyUpdate();
// GetLearningRate: non-virtual, returns a Dtype value.
// NOTE(review): presumably the learning rate for the current iteration --
// confirm against the definition.
27 Dtype GetLearningRate();
// Virtual hooks declared for the SGD solver; only declarations appear in this
// header. ComputeUpdateValue is redeclared with the same signature for the
// other solver types further down this file.
// NOTE(review): the per-function descriptions below are inferred from the
// names and signatures only -- confirm each against its definition.

// Normalize: takes an int param_id (presumably the index of the parameter
// whose gradient is to be rescaled).
31 virtual void Normalize(
int param_id);
// Regularize: takes an int param_id (presumably applies weight decay to that
// parameter's gradient).
32 virtual void Regularize(
int param_id);
// ComputeUpdateValue: takes an int param_id and a Dtype rate (presumably
// computes the update for that parameter at the given learning rate).
33 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
// ClipGradients: no arguments (presumably bounds the overall gradient
// magnitude before the update).
34 virtual void ClipGradients();
// SnapshotSolverState and its two format-specific variants each take the
// model file name as a const string reference.
// NOTE(review): presumably the generic entry point dispatches to the
// BinaryProto or HDF5 variant -- confirm in the definition.
35 virtual void SnapshotSolverState(
const string& model_filename);
36 virtual void SnapshotSolverStateToBinaryProto(
const string& model_filename);
37 virtual void SnapshotSolverStateToHDF5(
const string& model_filename);
// Restore counterparts: each takes the path of a previously written state
// file as a const string reference.
38 virtual void RestoreSolverStateFromHDF5(
const string& state_file);
39 virtual void RestoreSolverStateFromBinaryProto(
const string& state_file);
// Three vectors of shared blob pointers. history_ is the vector returned by
// history(); the roles of update_ and temp_ are not visible in this header
// -- TODO confirm in the implementation file.
44 vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
// DISABLE_COPY_AND_ASSIGN is a macro defined outside this file.
// NOTE(review): presumably it makes SGDSolver non-copyable -- confirm against
// the macro definition.
46 DISABLE_COPY_AND_ASSIGN(SGDSolver);
49 template <
typename Dtype>
56 virtual inline const char*
type()
const {
return "Nesterov"; }
// Virtual hook taking an int param_id and a Dtype rate; only the declaration
// is present here. It sits next to the type() accessor that returns
// "Nesterov".
// NOTE(review): presumably computes the Nesterov update for parameter
// param_id at the given learning rate -- confirm against the definition.
59 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
64 template <
typename Dtype>
71 virtual inline const char*
type()
const {
return "AdaGrad"; }
// Virtual hook taking an int param_id and a Dtype rate; only the declaration
// is present here. It sits next to the type() accessor that returns
// "AdaGrad".
// NOTE(review): presumably computes the AdaGrad update for parameter
// param_id at the given learning rate -- confirm against the definition.
74 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
// Validates the solver configuration for AdaGrad: the momentum value read
// from this->param_ must equal 0.
// CHECK_EQ is an assertion macro defined outside this file (glog-style);
// presumably a failed check aborts and logs the streamed message -- confirm.
75 void constructor_sanity_check() {
76 CHECK_EQ(0, this->param_.momentum())
77 <<
"Momentum cannot be used with AdaGrad.";
// DISABLE_COPY_AND_ASSIGN is a macro defined outside this file.
// NOTE(review): presumably it makes AdaGradSolver non-copyable -- confirm.
80 DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
84 template <
typename Dtype>
91 virtual inline const char*
type()
const {
return "RMSProp"; }
// Virtual hook taking an int param_id and a Dtype rate; only the declaration
// is present here. It sits next to the type() accessor that returns
// "RMSProp".
// NOTE(review): presumably computes the RMSProp update for parameter
// param_id at the given learning rate -- confirm against the definition.
94 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
// Validates the solver configuration for RMSProp using values read from
// this->param_. CHECK_EQ / CHECK_GE / CHECK_LT are assertion macros defined
// outside this file (glog-style); presumably a failed check aborts and logs
// the streamed message -- confirm.
95 void constructor_sanity_check() {
// Momentum must be exactly 0.
96 CHECK_EQ(0, this->param_.momentum())
97 <<
"Momentum cannot be used with RMSProp.";
// rms_decay must satisfy 0 <= rms_decay < 1. The bounds are enforced by two
// separate checks that share the same message.
98 CHECK_GE(this->param_.rms_decay(), 0)
99 <<
"rms_decay should lie between 0 and 1.";
100 CHECK_LT(this->param_.rms_decay(), 1)
101 <<
"rms_decay should lie between 0 and 1.";
// DISABLE_COPY_AND_ASSIGN is a macro defined outside this file.
// NOTE(review): presumably it makes RMSPropSolver non-copyable -- confirm.
104 DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
107 template <
typename Dtype>
114 virtual inline const char*
type()
const {
return "AdaDelta"; }
// AdaDeltaPreSolve: non-virtual, takes no arguments and returns nothing.
// NOTE(review): presumably one-time setup performed before solving --
// confirm against the definition.
117 void AdaDeltaPreSolve();
// ComputeUpdateValue: virtual hook taking an int param_id and a Dtype rate;
// only the declaration is present here.
// NOTE(review): presumably computes the AdaDelta update for parameter
// param_id at the given learning rate -- confirm against the definition.
118 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
131 template <
typename Dtype>
134 explicit AdamSolver(
const SolverParameter& param)
138 virtual inline const char*
type()
const {
return "Adam"; }
// Virtual hook taking an int param_id and a Dtype rate; only the declaration
// is present here. It sits next to the type() accessor that returns "Adam".
// NOTE(review): presumably computes the Adam update for parameter param_id
// at the given learning rate -- confirm against the definition.
142 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
149 #endif // CAFFE_SGD_SOLVERS_HPP_