solver.hpp

class Solver:

public:

  • Solver的构造和析构函数

    • explicit Solver( const SolverParameter& param, const Solver* root_solver=NULL)
    • explicit Solver( const string& param_file, const Solver* root_solver=NULL)
  • 初始化

    • void Init( const SolverParameter& param)
    • void InitTrainNet()
    • void InitTestNets()
  • void SetActionFunction( ActionCallback func); SolverAction::Enum GetRequestedAction() Solver的客户端可以调用SetActionFunction来设置一个函数,solver通过该函数查询接下来应执行的动作(例如保存snapshot或提前退出训练)

  • Solve()

    • virtual void Solve( const char* resume_file=NULL) solver函数的主要接口,默认情况下,iter为0。非零iter次数来继续训练已训练的网络
    • inline void Solve( const string resume_file) {}
  • void Step( int iters)

  • void Restore( const char* resume_file) Restore方法只是分派到某一个RestoreSolverStateFrom___方法(BinaryProto或HDF5)。子类需要实现这些方法,以便从相应类型的snapshot中恢复状态

  • virtual ~Solver() {}

  • Solver相关参数

    • inline const SolverParameter& param() const{}
    • inline shared_ptr< Net< Dtype> > net() {}
    • inline const vector< shared_ptr< Net< Dtype> > >& test_nets() {}
    • int iter() {}

class Callback

protected:

  • virtual void on_start()=0;

  • virtual void on_gradients_ready()=0

  • const vector< Callback*>& callbacks() const {}

  • void add_callback( Callback* value)

protected:

  • virtual void ApplyUpdate()=0 对当前迭代,创建和使用更新值

  • Snapshot()

    • void Snapshot() Solver::Snapshot函数使用基础的snapshotting工具来存储所学习的网络;并使用SnapshotSolverState()来生成SolverState protocol buffer,与所学习的网络一同写入硬盘
    • string SnapshotFilename( const string extension)
    • string SnapshotToBinaryProto()
    • string SnapshotToHDF5()
  • Test流程

    • void TestAll()
    • void Test( const int test_net_id=0)
    • virtual void SnapshotSolverState( const string& model_filename) = 0
    • virtual void RestoreSolverStateFromHDF5( const string& state_file) = 0
    • virtual void RestoreSolverStateFromBinaryProto( const string& state_file) = 0
    • void DisplayOutputBlobs( const int net_id)
  • Solver基本参数(成员数据)

    • SolverParameter param_
    • int iter_
    • int current_step_
    • shared_ptr< Net< Dtype> > net_
    • vector< shared_ptr< Net< Dtype> > > test_nets_
    • vector< Callback*> callbacks_
  • const Solver* const root_solver_ 在数据并行处理中,root solver保存有root net(实际上包含shared layer)

  • ActionCallback action_request_function_ 由Solver客户端设置的函数,用来表示客户端希望保存snapshot和/或提前退出训练

  • bool requested_early_exit_ 如果已收到提前停止的请求,则为true

class WorkerSolver: public Solver< Dtype>

public:

  • explicit WorkerSolver( const SolverParameter& param, const Solver< Dtype>* root_solver=NULL): Solver< Dtype>(param, root_solver){}

protected:

  • void ApplyUpdate(){}

  • void SnapshotSolverState( const string& model_filename) {}

  • void RestoreSolverStateFromBinaryProto( const string& state_file) {}

  • void RestoreSolverStateFromHDF5( const string& state_file)

class SGDSolver: public Solver< Dtype>

public:

  • explicit SGDSolver( const SolverParameter& param): Solver< Dtype>( param) {}

  • explicit SGDSolver( const string& param_file): Solver< Dtype>( param_file){}

  • const vector< shared_ptr< Blob< Dtype> > >& history() {}

protected:

  • void PreSolve()

  • Dtype GetLearningRate()

  • virtual void ApplyUpdate()

  • virtual void Normalize( int param_id)

  • virtual void Regularize( int param_id)

  • virtual void ComputeUpdateValue( int param_id, Dtype rate)

  • virtual void ClipGradients()

  • vector< shared_ptr< Blob< Dtype> > > history_, update_, temp_ history_保存历史的momentum数据;update_保存与更新相关的数据,snapshot时不需要保存;temp_保存计算梯度或更新值时所需的其他临时信息,snapshot时也不需要保存

class NesterovSolver: public SGDSolver< Dtype>

class AdaGradSolver: public SGDSolver< Dtype>

class RMSPropSolver: public SGDSolver< Dtype>