1 #ifndef CAFFE_SOLVER_HPP_
2 #define CAFFE_SOLVER_HPP_
3 #include <boost/function.hpp>
7 #include "caffe/net.hpp"
8 #include "caffe/solver_factory.hpp"
9 #include "caffe/util/benchmark.hpp"
21 namespace SolverAction {
41 template <
typename Dtype>
44 explicit Solver(
const SolverParameter& param);
45 explicit Solver(
const string& param_file);
46 void Init(
const SolverParameter& param);
54 SolverAction::Enum GetRequestedAction();
57 virtual void Solve(
const char* resume_file = NULL);
58 inline void Solve(
const string& resume_file) { Solve(resume_file.c_str()); }
// Restores solver state from `resume_file`.
// NOTE(review): presumably dispatches to RestoreSolverStateFromHDF5 or
// RestoreSolverStateFromBinaryProto by file type — confirm in the .cpp.
void Restore(const char* resume_file);
70 inline const SolverParameter& param()
const {
return param_; }
71 inline shared_ptr<Net<Dtype> > net() {
return net_; }
72 inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
75 int iter()
const {
return iter_; }
80 virtual void on_start() = 0;
81 virtual void on_gradients_ready() = 0;
86 const vector<Callback*>& callbacks()
const {
return callbacks_; }
87 void add_callback(Callback* value) {
88 callbacks_.push_back(value);
// Verifies that snapshot output can be written.
// NOTE(review): the failure behaviour (log, abort, ...) lives in the .cpp.
void CheckSnapshotWritePermissions();
95 virtual inline const char*
type()
const {
return ""; }
98 virtual void ApplyUpdate() = 0;
101 string SnapshotFilename(
const string& extension);
102 string SnapshotToBinaryProto();
103 string SnapshotToHDF5();
106 void Test(
const int test_net_id = 0);
107 virtual void SnapshotSolverState(
const string& model_filename) = 0;
108 virtual void RestoreSolverStateFromHDF5(
const string& state_file) = 0;
109 virtual void RestoreSolverStateFromBinaryProto(
const string& state_file) = 0;
110 void DisplayOutputBlobs(
const int net_id);
111 void UpdateSmoothedLoss(Dtype loss,
int start_iter,
int average_loss);
113 SolverParameter param_;
116 shared_ptr<Net<Dtype> > net_;
117 vector<shared_ptr<Net<Dtype> > > test_nets_;
118 vector<Callback*> callbacks_;
119 vector<Dtype> losses_;
120 Dtype smoothed_loss_;
127 bool requested_early_exit_;
130 Timer iteration_timer_;
131 float iterations_last_;
133 DISABLE_COPY_AND_ASSIGN(
Solver);
138 #endif // CAFFE_SOLVER_HPP_