概览
这篇博客解析caffe函数入口caffe.cpp,主要内容为caffe启动框架,基本不涉及深度学习的具体内容,内容十分基础,适合新手阅读。下面所有的代码解析都以训练lenet手写数字体识别为例,其运行参数为:
caffe train --solver=examples/mnist/lenet_solver.prototxt $@
main函数
先把main函数贴上来
int main(int argc, char** argv) { // Print output to stderr (while still logging). FLAGS_alsologtostderr = 1; // Set version gflags::SetVersionString(AS_STRING(CAFFE_VERSION)); // Usage message. gflags::SetUsageMessage("command line brew\n" "usage: caffe <command> <args>\n\n" "commands:\n" " train train or finetune a model\n" " test score a model\n" " device_query show GPU diagnostic information\n" " time benchmark model execution time"); // Run tool or show usage. caffe::GlobalInit(&argc, &argv); if (argc == 2) { #ifdef WITH_PYTHON_LAYER try { #endif return GetBrewFunction(caffe::string(argv[1]))(); #ifdef WITH_PYTHON_LAYER } catch (bp::error_already_set) { PyErr_Print(); return 1; } #endif } else { gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe"); } }
main函数上来就是一个变量FLAGS_alsologtostderr,但vscode找不到该变量的定义。其实这个变量包括其他带有FLAGS前缀的变量是由gflags定义的,gflags 是 google 开源的用于处理命令行参数的项目。alsologtostderr指将日志输出到标准错误流中去。后面SetVersionString 的作用是当你使用caffe --version时能打印出caffe的版本信息,CAFFE_VERSION由Makefile指定.紧接着SetUsageMessage实际上设置了caffe的帮助信息,当运行caffe参数不正确或者使用--help参数时打印出usage信息。caffe::GlobalInit函数会根据命令行参数做一些初始化的工作,其定义在common.cpp中,具体如下:
void GlobalInit(int* pargc, char*** pargv) { // Google flags. ::gflags::ParseCommandLineFlags(pargc, pargv, true); // Google logging. ::google::InitGoogleLogging(*(pargv)[0]); // Provide a backtrace on segfault. ::google::InstallFailureSignalHandler(); }
对于训练手写数字体识别:
只有一个参数solver =examples/mnist/lenet_solver.prototxt 解析后可以以FLAGS_solver来访问。包括solver model等用户自定义的命令行参数(非gflags默认的参数)定义在caffe.cpp里:
DEFINE_string(gpu, "", "Optional; run in GPU mode on given device IDs separated by ','." "Use '-gpu all' to run on all available GPUs. The effective training " "batch size is multiplied by the number of devices."); DEFINE_string(solver, "", "The solver definition protocol buffer text file."); DEFINE_string(model, "", "The model definition protocol buffer text file.");
对于gflags更详细的信息可以参考google gflags 库完全使用
后面的InitGoogleLogging和InstallFailureSignalHandler用来处理日志和运行错误。
那么main函数怎么根据train test等参数进入到相应的train函数或test函数中去呢?
看这一行代码:
return GetBrewFunction(caffe::string(argv[1]))();
这个函数可以根据第一个参数argv[1](argv[0]是caffe本身的路径)来返回相应的函数,接下来我们来看GetBrewFunction是怎么实现这个功能的。
typedef int (*BrewFunction)(); //定义了一个函数指针类型,该类型指针指向一个参数为空返回值为int的函数 typedef std::map<caffe::string, BrewFunction> BrewMap;//定义了一个map类型,该类型的变量维护一个字典,函数名称(string)作为key,函数指针(BrewFunction)作为value BrewMap g_brew_map; #define RegisterBrewFunction(func) \ namespace { \ class __Registerer_##func { \ //##表示合并字符串 public: /* NOLINT */ \ __Registerer_##func() { \ g_brew_map[#func] = &func; \ #为字符串 } \ }; \ __Registerer_##func g_registerer_##func; \ } static BrewFunction GetBrewFunction(const caffe::string& name) { if (g_brew_map.count(name)) { return g_brew_map[name];//根据name中的具体内容返回相应的函数指针 } else { LOG(ERROR) << "Available caffe actions:"; for (BrewMap::iterator it = g_brew_map.begin(); it != g_brew_map.end(); ++it) { LOG(ERROR) << "\t" << it->first; } LOG(FATAL) << "Unknown action: " << name; return NULL; // not reachable, just to suppress old compiler warnings. } }
//下面是一个例子,详细说明train函数怎么填充到g_brew_map中 int train(){ } RegisterBrewFunction(train)//这一句会根据宏定义被替换成下面的内容 namespace{ class __Registerer_train{ public: __Registerer_train(){ g_brew_map["train"] = &train; } }; __Registerer_train g_registerer_train; //实例化的过程中将train函数填充到字典g_brew_map中去了 }
根据上面一些注释,我们可以看出一个大概的框架:
1 定义一个字典,存储函数名到函数指针的映射。
2 通过RegisterBrewFunction(func)的宏定义来填充这个字典。
3 调用GetBrewFunction根据函数名返回相应的函数指针。
train函数
下面具体看train函数
// Train / Finetune a model. int train() { CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; //FLAGS_solver <= 0 会输出 CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())// snapshot 和 weight参数都没有,不管 << "Give a snapshot to resume training or weights to finetune " "but not both."; vector<string> stages = get_stages_from_flags(); //stages参数也没有,跳过 caffe::SolverParameter solver_param; caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//该行从lenet_solver.prototxt读取参数到solver_param中 solver_param.mutable_train_state()->set_level(FLAGS_level); //level参数也没有,跳过 for (int i = 0; i < stages.size(); i++) { solver_param.mutable_train_state()->add_stage(stages[i]); } // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. if (FLAGS_gpu.size() == 0 //从solverparam中读取GPU的信息,是否使用GPU,GPU的id之类的,初期可以不用特别关注 && solver_param.has_solver_mode() && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { if (solver_param.has_device_id()) { FLAGS_gpu = "" + boost::lexical_cast<string>(solver_param.device_id()); } else { // Set default GPU if unspecified FLAGS_gpu = "" + boost::lexical_cast<string>(0); } } vector<int> gpus; get_gpus(&gpus); if (gpus.size() == 0) { LOG(INFO) << "Use CPU."; Caffe::set_mode(Caffe::CPU); } else { ostringstream s; for (int i = 0; i < gpus.size(); ++i) { s << (i ? ", " : "") << gpus[i]; } LOG(INFO) << "Using GPUs " << s.str(); #ifndef CPU_ONLY cudaDeviceProp device_prop; for (int i = 0; i < gpus.size(); ++i) { cudaGetDeviceProperties(&device_prop, gpus[i]); LOG(INFO) << "GPU " << gpus[i] << ": " << device_prop.name; } #endif solver_param.set_device_id(gpus[0]); Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); Caffe::set_solver_count(gpus.size()); } caffe::SignalHandler signal_handler( GetRequestedAction(FLAGS_sigint_effect), GetRequestedAction(FLAGS_sighup_effect)); if (FLAGS_snapshot.size()) { solver_param.clear_weights(); } else if (FLAGS_weights.size()) { solver_param.clear_weights(); solver_param.add_weights(FLAGS_weights); } //根据solver_param,生成solver shared_ptr<caffe::Solver<float> > solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)); solver->SetActionFunction(signal_handler.GetActionFunction()); if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; solver->Restore(FLAGS_snapshot.c_str()); } LOG(INFO) << "Starting Optimization"; if (gpus.size() > 1) { #ifdef USE_NCCL caffe::NCCL<float> nccl(solver); nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL); #else LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL"; #endif } else { //求解solver solver->Solve(); } LOG(INFO) << "Optimization Done."; return 0; }
solver的实例化
这里不涉及任何solver内部的细节,包括生成_net和test_net,具体的求解方法等内容,只剖析caffe怎样根据solverparam.type实例化不同的solver类。实际上这些内容和上面讲的根据命令行参数执行train还是test等函数的方法十分相似,但其过程更加复杂,还是简要的分析一下。
shared_ptr<caffe::Solver<float>> solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver<float>
的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。
下面分析SolverRegistry具体是怎么做的:
typedef Solver<Dtype>* (*Creator)(const SolverParameter&); typedef std::map<string, Creator> CreatorRegistry; static CreatorRegistry& Registry() { static CreatorRegistry* g_registry_ = new CreatorRegistry(); return *g_registry_; }
static Solver<Dtype>* CreateSolver(const SolverParameter& param) { const string& type = param.type(); CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type << " (known types: " << SolverTypeListString() << ")"; return registry[type](param); }
从上述代码可以看到也是维护了一个map由solverparam.type返回具体的solver<Dtype>指针
SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。 CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*
。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*
返回。
Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,可以在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。
那包括SGDSolver等各种solver是怎么注册的呢?下面以注册SGDSolver为例说明
solver_factory.hpp文件中有两个宏定义如下:
#define REGISTER_SOLVER_CREATOR(type, creator) \ static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \ static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \ #define REGISTER_SOLVER_CLASS(type) \ template <typename Dtype> \ Solver<Dtype>* Creator_##type##Solver( \ const SolverParameter& param) \ { \ return new type##Solver<Dtype>(param); \ } \ REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
sgd_solver.cpp文件末尾有
REGISTER_SOLVER_CLASS(SGD);
根据宏定义替换的结果如下:
template <typename Dtype> Solver<Dtype>* Creator_SGD_Solver(const SolverParameter& param) { return new SGDSolver<Dtype>(param); } static SolverRegisterer<float> g_creator_f_SGD("SGD",Creator_SGD_Solver<float>); static SolverRegisterer<double> g_creator_f_SGD("SGD",Creator_SGD_Solver<double>);
即根据宏定义,定义了一个Creator函数指针可指的函数Creator_SGD_Solver,然后通过下面的函数将key和value注册进去:
template <typename Dtype> class SolverRegisterer { public: SolverRegisterer(const string& type, Solver<Dtype>* (*creator)(const SolverParameter&)) { // LOG(INFO) << "Registering solver type: " << type; SolverRegistry<Dtype>::AddCreator(type, creator); } };
AddCreator函数的源码不在此展示,具体细节阅读solver_factory.hpp
至此,生成solver的工厂模式应该讲清楚了,caffe的启动框架也差不多清晰了,接下来就是solver怎么根据solver_params生成net,以及net的前向和反向计算了。
参考:
共同学习,写下你的评论
评论加载中...
作者其他优质文章