38 #ifndef CAFFE_SOLVER_FACTORY_H_
39 #define CAFFE_SOLVER_FACTORY_H_
45 #include "caffe/common.hpp"
46 #include "caffe/proto/caffe.pb.h"
// NOTE(review): this listing looks like a scraped source view. Every line
// carries its original line number as literal text, and several original
// lines (e.g. 47-49, 51-52, 54-56, 58) are absent. As a result, the
// declarations these template heads introduce are not visible here and the
// text will not compile as-is.
// Presumably the missing line 52 forward-declares `class Solver;` -- TODO confirm.
50 template <
typename Dtype>
// Presumably the head of the SolverRegistry class template (its `class` line
// is not visible) -- TODO confirm against the upstream header.
53 template <
typename Dtype>
// Maps a solver type name to the factory function (`Creator`) registered for
// it. `Creator` itself is declared on a line that is not visible here.
57 typedef std::map<string, Creator> CreatorRegistry;
// Accessor for the name -> creator registry. Per its signature, it returns a
// reference to a CreatorRegistry.
// The map is allocated with `new` exactly once, on first call (the pointer is
// a function-local static). No `delete` is visible, so it is presumably kept
// alive for the whole process on purpose -- confirm.
// NOTE(review): the `return` statement (presumably `return *g_registry_;`)
// and the closing brace are missing from this listing.
59 static CreatorRegistry& Registry() {
60 static CreatorRegistry* g_registry_ =
new CreatorRegistry();
// Registers `creator` under the solver type name `type`.
// Registering the same name twice fails the CHECK_EQ with a
// "Solver type <type> already registered." message. CHECK_EQ is presumably
// glog's aborting check -- confirm.
// NOTE(review): the function's closing brace is missing from this listing.
65 static void AddCreator(
const string& type, Creator creator) {
66 CreatorRegistry& registry = Registry();
67 CHECK_EQ(registry.count(type), 0)
68 <<
"Solver type " << type <<
" already registered.";
69 registry[type] = creator;
// Builds a solver for `param`: looks up `param.type()` in the registry and
// calls the matching creator with `param`, returning the pointer it produces.
// An unregistered type fails the CHECK_EQ. The failure message lists every
// registered type via SolverTypeListString().
// Ownership of the returned raw pointer presumably passes to the caller --
// confirm at call sites.
// NOTE(review): the function's closing brace is missing from this listing.
73 static Solver<Dtype>* CreateSolver(
const SolverParameter& param) {
74 const string& type = param.
type();
75 CreatorRegistry& registry = Registry();
76 CHECK_EQ(registry.count(type), 1) <<
"Unknown solver type: " << type
77 <<
" (known types: " << SolverTypeListString() <<
")";
78 return registry[type](param);
// Collects the names (map keys) of all registered solver types.
// CreatorRegistry is a std::map, so the names come out in ascending key order.
// NOTE(review): the loop's closing brace, the `return solver_types;`
// statement and the function's closing brace (original lines 87-89) are
// missing from this listing.
81 static vector<string> SolverTypeList() {
82 CreatorRegistry& registry = Registry();
83 vector<string> solver_types;
84 for (
typename CreatorRegistry::iterator iter = registry.begin();
85 iter != registry.end(); ++iter) {
86 solver_types.push_back(iter->first);
// Joins SolverTypeList() into one string separated by ", "
// (e.g. "A, B, C"). An empty registry yields an empty string.
// CreateSolver uses it to build its "known types" diagnostic.
// NOTE(review): the closing braces of the `if`, the loop and the function
// (original lines 103, 105, 107) are missing from this listing.
96 static string SolverTypeListString() {
97 vector<string> solver_types = SolverTypeList();
98 string solver_types_str;
99 for (vector<string>::iterator iter = solver_types.begin();
100 iter != solver_types.end(); ++iter) {
101 if (iter != solver_types.begin()) {
102 solver_types_str +=
", ";
104 solver_types_str += *iter;
106 return solver_types_str;
// NOTE(review): the declaration this template head introduces (original
// lines 112-121) is missing from this listing. It is presumably the
// SolverRegisterer<Dtype> class that REGISTER_SOLVER_CREATOR instantiates --
// TODO confirm against the upstream header.
111 template <
typename Dtype>
// Registers the float and double instantiations of the creator template
// `creator` under the solver type name #type. It does this by defining one
// file-static SolverRegisterer per precision. The SolverRegisterer
// constructors are not visible here; presumably they call
// SolverRegistry<Dtype>::AddCreator.
// The second declaration has no ';': the use site's trailing ';' completes it.
// The last line deliberately has NO trailing backslash. A dangling '\' would
// splice whatever line follows (e.g. the next #define) into this macro.
#define REGISTER_SOLVER_CREATOR(type, creator)                               \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);  \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)
// Defines a creator function template Creator_<type>Solver<Dtype> that
// returns `new <type>Solver<Dtype>(param)`, then registers its float and
// double instantiations via REGISTER_SOLVER_CREATOR under the name #type.
// Example: REGISTER_SOLVER_CLASS(SGD); requires a class template
// SGDSolver<Dtype> constructible from a `const SolverParameter&`.
// The solver is heap-allocated. Ownership presumably passes to whoever
// invokes the registered creator -- confirm.
#define REGISTER_SOLVER_CLASS(type)                                          \
  template <typename Dtype>                                                  \
  Solver<Dtype>* Creator_##type##Solver(                                     \
      const SolverParameter& param)                                          \
  {                                                                          \
    return new type##Solver<Dtype>(param);                                   \
  }                                                                          \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
137 #endif // CAFFE_SOLVER_FACTORY_H_