Caffe的solver.hpp解读
solver.hpp
class Solver:
public:
- 
    
Solver的构造和析构函数
- explicit Solver( const SolverParameter& param, const Solver* root_solver=NULL)
 - explicit Solver( const string& param_file, const Solver* root_solver=NULL)
 
 - 
    
初始化
- void Init( const SolverParameter& param)
 - void InitTrainNet()
 - void InitTestNets()
 
 - 
    
void SetActionFunction( ActionCallback func); SolverAction::Enum GetRequestedAction() Solver的客户端可以调用SetActionFunction来设置一个回调函数,solver通过该函数判断应当执行哪种动作(例如保存snapshot或提前退出训练)
 - 
    
Solve()
- virtual void Solve( const char* resume_file=NULL) solver函数的主要入口。默认情况下,iter为0;传入非零的iter次数可以在已预训练的网络上继续训练
 - inline void Solve( const string resume_file) {}
 
 - 
    
void Step( int iters)
 - 
    
void Restore( const char* resume_file) Restore方法只是分派到某一个RestoreSolverStateFrom___保护方法。应当实现这些方法,以便从相应类型的snapshot中恢复状态
 - 
    
virtual ~Solver() {}
 - 
    
Solver相关参数
- inline const SolverParameter& param() const{}
 - inline shared_ptr< Net< Dtype> > net() {}
 - inline const vector< shared_ptr< Net< Dtype> > >& test_nets() {}
 - int iter() {}
 
 
class Callback
protected:
- 
    
virtual void on_start()=0;
 - 
    
virtual void on_gradients_ready()=0
 - 
    
const vector< Callback*>& callbacks() const {}
 - 
    
void add_callback( Callback* value)
 
protected:
- 
    
virtual void ApplyUpdate()=0 对当前迭代,创建和使用更新值
 - 
    
Snapshot()
- void Snapshot() Solver::Snapshot函数实现了基础的snapshot功能,用来存储所学习的网络;子类应当实现SnapshotSolverState()来生成SolverState protocol buffer,它需要与所学习的网络一同写入硬盘
 - string SnapshotFilename( const string extension)
 - string SnapshotToBinaryProto()
 - string SnapshotToHDF5()
 
 - 
    
Test流程
- void TestAll()
 - void Test( const int test_net_id=0)
 - virtual void SnapshotSolverState( const string& model_filename) = 0
 - virtual void RestoreSolverStateFromHDF5( const string& state_file) = 0
 - virtual void RestoreSolverStateFromBinaryProto( const string& state_file) = 0
 - void DisplayOutputBlobs( const int net_id)
 
 - 
    
Solver基本参数(成员数据)
- SolverParameter param_
 - int iter_
 - int current_step_
 - shared_ptr< Net< Dtype> > net_
 - vector< shared_ptr< Net< Dtype> > > test_nets_
 - vector< Callback*> callbacks_
 
 - 
    
const Solver* const root_solver_ 在数据并行处理中,root solver保存有root net(实际上包含shared layer)
 - 
    
ActionCallback action_request_function_ 可由Solver客户端设置的函数,用来表示需要保存snapshot和/或提前退出
 - 
    
bool requested_early_exit_ 当且仅当收到提前停止的请求时,该值为true
 
class WorkerSolver: public Solver< Dtype>
public:
- explicit WorkerSolver( const SolverParameter& param, const Solver< Dtype>* root_solver=NULL): Solver< Dtype>(param, root_solver){}
 
protected:
- 
    
void ApplyUpdate(){}
 - 
    
void SnapshotSolverState( const string& model_filename) {}
 - 
    
void RestoreSolverStateFromBinaryProto( const string& state_file) {}
 - 
    
void RestoreSolverStateFromHDF5( const string& state_file)
 
class SGDSolver: public Solver< Dtype>
public:
- 
    
explicit SGDSolver( const SolverParameter& param): Solver< Dtype>( param) {}
 - 
    
explicit SGDSolver( const string& param_file): Solver< Dtype>( param_file){}
 - 
    
const vector< shared_ptr< Blob< Dtype> > >& history() {}
 
protected:
- 
    
void PreSolve()
 - 
    
Dtype GetLearningRate()
 - 
    
virtual void ApplyUpdate()
 - 
    
virtual void Normalize( int param_id)
 - 
    
virtual void Regularize( int param_id)
 - 
    
virtual void ComputeUpdateValue( int param_id, Dtype rate)
 - 
    
virtual void ClipGradients()
 - 
    
vector< shared_ptr< Blob< Dtype> > > history_, update_, temp_ history_保存历史的momentum数据;update_保存与更新相关的数据,snapshot中不需要;temp_保存计算梯度或更新值时可能需要的其他信息,snapshot中也不需要