前几天跑了一个mnist的例子,今天开启源码分析之路.不过本系列不会拘泥于每一行代码,只是从总体框架出发,把源码理顺.第一篇博文就是为了理顺程序运行过程.

疑问

众所周知,caffe不需写代码,只要配置几个配置文档就可以运行程序,这对于程序员的我还是挺神奇的.
caffe使用C++写的,可以肯定caffe程序是c++程序,那么那些非C++配置文档又是怎么和C++代码融合的了?

Solver的运行流程

声明:本篇只是在solver层面做整体流程分析,屏蔽Net,Layer,Blob层面的细节.
对于上面的疑问,我们要从配置文档说起.其实它也是一种满足Google Protocol Buffer格式的代码,可以通过相应的工具将其转为C++.我们只不过是以一种简单的方式写代码而已.

入口

通常代码都有一个主程序入口,caffe也不例外.一切还得从下面这条语句说起:

1
./build/tools/caffe train -solver=examples/kaggle_mnist/lenet_solver.prototxt

显然这是一条命令行(Command Line)语句.通过查caffe官网关于Command Line 接口的说明,发现它有四种模式:train ,test ,time ,device_query.显然上面命令行中的train指代其中一种模式,后面的一长窜指命令行参数solver以及它的值.

接下来就得分析caffe这条命令,它的源码就在caffe/tools/caffe.cpp里.下面是它的main函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
int main(int argc, char** argv) {
// Print output to stderr (while still logging).
FLAGS_alsologtostderr = 1;
// Set version
gflags::SetVersionString(AS_STRING(CAFFE_VERSION));
// Usage message.
gflags::SetUsageMessage("command line brew\n"
"usage: caffe <command> <args>\n\n"
"commands:\n"
" train train or finetune a model\n"
" test score a model\n"
" device_query show GPU diagnostic information\n"
" time benchmark model execution time");
// Run tool or show usage.
caffe::GlobalInit(&argc, &argv);
if (argc == 2) {
#ifdef WITH_PYTHON_LAYER
try {
#endif
return GetBrewFunction(caffe::string(argv[1]))();
#ifdef WITH_PYTHON_LAYER
} catch (bp::error_already_set) {
PyErr_Print();
return 1;
}
#endif
} else {
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
}
}

首先要对argc,argv有所了解:
argc是命令行总的参数个数
argv[]是argc个参数,其中第0个参数是程序的全名(即main)
argv[]剩余的参数是命令行后面用户输入的参数.

对应本例argc=2,argv[0]=main,argv[1]=train.别急,还有一个solver参数,它是有gflags依赖库来解析的.在caffe.cpp中有很多类似下面代码的定义:

1
DEFINE_string(solver, "","The solver definition protocol buffer text file.");

其中,DEFINE_string是gflags中的函数,把solver解析成FLAGS_solver,其值为examples/kaggle_mnist/lenet_solver.prototxt.不过这都不是本节的重点.以上的解析工作都是在下面这条语句完成的:

1
caffe::GlobalInit(&argc, &argv);

接下来就是重头戏,就是这条语句:

1
return GetBrewFunction(caffe::string(argv[1]))();

仔细看这条语句是执行一个函数,并返回.程序到此就结束了.瞬间这条函数重要性就突显出来了,因为前面分析的过程其实我们什么也没有做,那么这条函数就是核心代码了.

GetBrewFunction函数返回的是键值为argv[1]的map值,那么我们只要知道这个map是什么就好办了.

仔细看源码,在main函数之前,有如下四条语句:

1
2
3
4
RegisterBrewFunction(device_query);
RegisterBrewFunction(train);
RegisterBrewFunction(test);
RegisterBrewFunction(time);

其中,RegisterBrewFunction是一个宏,定义了一个namespace,有一个类以及它的构造函数, 且构造函数里有一个全局map,map的key为string类型,value为与key同名的函数指针.

由于上面的四条语句在main函数之前,那么他们相应的构造函数也在main之前执行了.也就是说进入main函数之前有map了.

因此 GetBrewFunction(caffe::string(argv[1]))返回的就是argv[1],在本例中也就是train了.

我们重要找到那个重要的函数了,它就是train()

train函数

其实train函数中也没有干核心的活,而是把它扔给了别人:

1
2
3
4
5
6
7
if (gpus.size() > 1) {
caffe::P2PSync<float> sync(solver, NULL, solver->param());
sync.Run(gpus);
} else {
LOG(INFO) << "Starting Optimization";
solver->Solve();
}

不过它也是干了一些活的,还记得前面说的solver的解析吗?它的主要工作就在这儿,就是把配置文档解析到一个SolverParameter.具体的过程这里就不写了.

Solve函数

只考虑CPU-only的情况.问题就全归结到:

1
solver->Solve();

首先声明,caffe中有很多solver,那么到底执行哪一个,因为我们不可能在solver这个层面写具体是哪一个solver,否则我们就会对每一个solver实现一次执行的详细过程.其实解决的方法,train函数的生成是一样的,也是注册一个map,最后根据传进来的solver做映射

其实在solve函数里也没有做实质的工作,任务都扔给了step函数

Step函数

step函数则抽象的规划了整个迭代过程,做了前向计算.剩余的都扔给了具体用到的子solver的ApplyUpdate函数,也就是说这两个函数才是正真的核心.

强力推荐一下两篇,作者写得很详细:

caffe 命令行接口源码分析

caffe solver相关源码分析