Caffe的solver.hpp解读
solver.hpp
class Solver:
public:
-
Solver的构造和析构函数
- explicit Solver( const SolverParameter& param, const Solver* root_solver=NULL)
- explicit Solver( const string& param_file, const Solver* root_solver=NULL)
-
初始化
- void Init( const SolverParameter& param)
- void InitTrainNet()
- void InitTestNets()
-
void SetActionFunction( ActionCallback func); SolverAction::Enum GetRequestedAction() Solver的客户端可以(可选地)调用SetActionFunction来设置一个回调函数,solver通过该函数判断接下来应执行哪种操作(例如保存snapshot或提前退出训练)
-
Solve()
- virtual void Solve( const char* resume_file=NULL) solver的主要入口函数。默认情况下iter为0;传入非零的iter次数可以继续训练一个已预训练的网络
- inline void Solve( const string resume_file) {} 接受string参数的重载版本
-
void Step( int iters)
-
void Restore( const char* resume_file) Restore方法只是将调用分派给某个RestoreSolverStateFrom___保护方法;子类需要实现这些方法,以便从相应类型的snapshot中恢复solver状态
-
virtual ~Solver() {}
-
Solver相关参数
- inline const SolverParameter& param() const{}
- inline shared_ptr< Net< Dtype> > net() {}
- inline const vector< shared_ptr< Net< Dtype> > >& test_nets() {}
- int iter() {}
class Callback
protected:
-
virtual void on_start()=0;
-
virtual void on_gradients_ready()=0;
-
const vector< Callback*>& callbacks() const {}
-
void add_callback( Callback* value)
protected:
-
virtual void ApplyUpdate()=0 为当前迭代计算并应用参数更新值
-
Snapshot()
- void Snapshot() Solver::Snapshot函数使用基础的snapshot工具来存储学习到的网络;同时调用SnapshotSolverState()生成SolverState protocol buffer,并将其与学习到的网络一同写入磁盘
- string SnapshotFilename( const string extension)
- string SnapshotToBinaryProto()
- string SnapshotToHDF5()
-
Test流程
- void TestAll()
- void Test( const int test_net_id=0)
- virtual void SnapshotSolverState( const string& model_filename) = 0
- virtual void RestoreSolverStateFromHDF5( const string& state_file) = 0
- virtual void RestoreSolverStateFromBinaryProto( const string& state_file) = 0
- void DisplayOutputBlobs( const int net_id)
-
Solver基本参数(成员数据)
- SolverParameter param_
- int iter_
- int current_step_
- shared_ptr< Net< Dtype> > net_
- vector< shared_ptr< Net< Dtype> > > test_nets_
- vector< Callback*> callbacks_
-
const Solver* const root_solver_ 在数据并行中,root solver持有root net(其中实际包含共享的layer)
-
ActionCallback action_request_function_ 由Solver客户端设置的函数,用来通知solver需要保存snapshot和/或提前退出
-
bool requested_early_exit_ 当且仅当收到了提前停止的请求时为true
class WorkerSolver: public Solver< Dtype>
public:
- explicit WorkerSolver( const SolverParameter& param, const Solver< Dtype>* root_solver=NULL): Solver< Dtype>(param, root_solver){}
protected:
-
void ApplyUpdate(){}
-
void SnapshotSolverState( const string& model_filename) {}
-
void RestoreSolverStateFromBinaryProto( const string& state_file) {}
-
void RestoreSolverStateFromHDF5( const string& state_file) {}
class SGDSolver: public Solver< Dtype>
public:
-
explicit SGDSolver( const SolverParameter& param): Solver< Dtype>( param) {}
-
explicit SGDSolver( const string& param_file): Solver< Dtype>( param_file){}
-
const vector< shared_ptr< Blob< Dtype> > >& history() {}
protected:
-
void PreSolve()
-
Dtype GetLearningRate()
-
virtual void ApplyUpdate()
-
virtual void Normalize( int param_id)
-
virtual void Regularize( int param_id)
-
virtual void ComputeUpdateValue( int param_id, Dtype rate)
-
virtual void ClipGradients()
-
vector< shared_ptr< Blob< Dtype> > > history_, update_, temp_ history_保存历史momentum数据;update_保存与参数更新相关的数据,snapshot中不需要;temp_保存计算梯度或更新值时可能用到的其他临时信息,snapshot中同样不需要