XGBoost_源码初探
XGBoost_ 源码初探
##1. 说明
本篇来读读 Xgboost 源码。其核心代码基本在 src 目录下,由 C++ 实现,40 几个 cc 文件,代码 11000 多行,虽然不算太多,但想把核心代码都读明白,也需要很长时间。我觉得阅读的目的主要是:了解基本原理,流程,核心代码的位置,修改从哪儿入手,而得以快速入门。因此,需要跟踪代码执行过程,同时查看在某一步骤其内部环境的取值情况。具体方法是:单步调试或在代码中加入一些打印信息,因此选择了安装编译源码的方式。
##2. 下载编译
用参数 --recursive 可以下载它的支持包 rabit 和 cur,否则编不过
1 | $ git clone --recursive https://github.com/dmlc/xgboost |
##3. 运行
测试程序 demo 目录中有多分类,二分类,回归等各种示例,这里从二分类入手。
1 | $ cd demo |
##4. 主流程
下面从 main() 开始,看看程序执行的主要流程,下图是一个示意图,每个黄色框对应一个 cc 文件,可以将它视作调用关系图,并非完全按照类图绘制,同时省略了一些主流程以外的细节,请各位以领会精神为主。
1) Src/cli_main.cc:(主程序入口)
CLIRunTask():解析参数,提供三个主要功能:训练,打印模型,预测.
CLITrain():训练部分,装载数据后,主要调用学习器 Learner 实际功能(配置 cofigure,迭代,评估,存储……),其中的 for 循环包含迭代调用计算和评估。
2) Src/learner.cc:(学习器)
定义三个核心句柄 gbm_(子模型 tree/linear),obj_(损失函数),metrics_(评价函数)
UpdateOneIter():此函数会在每次迭代时被调用,主要包含四个步骤:调整参数(LazyInitDMatrix()),用当前模型预测(PredictRaw(),gbm_-> PredictBatch()),求当前预测结果和实际值的差异的方向(obj_->GetGradient()),根据差异修改模型(gbm_->DoBoost()),后面逐一细化。
EvalOneIter() 支持对多个评价数据集分别评价,对每个数据集,先进行预测(PredictRaw()),评价(obj_->EvalTransform()),再调 metrics_ 中的各个评价器,输出结果。
3) Src/metric/metric.cc(评价函数入口)
基本上,每个目录都有一个入口函数,metric.cc 是评价函数的入口,learn 允许同时支持多个评价函数(注意评价函数和误差函数不同)。主要三种评价函数:多分类,排序,元素评价,分别定义在三个文件之中。
4) Src/objective/objective.cc(损失函数入口)
objective.cc 是损失函数的入口,Learner::load() 函数调用 Create() 创建误失函数,该目录中实现了:多分类,回归,排序的多种损失函数(每个对应一个文件),每个损函数最核心的功能是 GetGradient(),另外也可以参考 plugin 中示例,自定义损失函数。例如:src/objective/regression_obj.cc(最常用的损失函数 RegLossObj())计算一阶导,二阶导,存入 gpair 结构。这里加入了样本的权重,scale_pos_weight 也是在此处起作用。
5) src/gbm/gbm.cc(迭代器 Gradient Booster)
这里是对模型的封装,主要支持 tree 和 linear 两种方式,树分类器又包含 GBTree 和 Dart 两种,Dart 主要加入了归一化和 dropout 防过拟合,详见参考部分。gbm.cc 中也有三个重要句柄:model_ 存储当前模型数据,updaters_ 管理每一次迭代的更新算法,predictor_ 用于预测
DoBoost() 和 BoostNewTrees() 进一步迭代生成新树,详建更新器部分
Predict*() 调用各种预测,详见预测部分
6) src/predictor/predictor.cc(预测工具入口)
predictor.cc 也是一个入口,可调用 cpu 和 gpu 两种预测方式。
- PredValue():核心函数,计算了从训练到当前迭代的所有回归树集合(以回归树为例)。
7) src/tree/tree_updater.cc(树模型的具体实现)
src/tree 和 src/linear 分别是树和线性模型的具体实现,tree_updater 是 updater 的入口,每一个 Updater 是对一棵树进行一次更新。其中的 Updater 分为两类:计算类和辅助类,updater都继承于 TreeUpdater,互相之间又有调有关系,比如:prune 调用 sync,colmaker 和 fast_hist 调用 prune。
以下为辅助类:
Src/tree/updater_prune.cc 用于剪枝
Src/tree/updater_refresh.cc 用于更新权重和统计值
Src/tree/updater_sync.cc 用于在分布式系统的节点间同步数据
Src/tree/split_evaluator.cc 定义了两种切分方法:弹性网络 elastic net, 单调约束 monotonic,在此为切分评分,正则项在此发挥作用。打发的依据是差值,权重和正则化项。
以下为算法类(基本都在 xgboost 论文第三章描述) 对于树算法,最核心的是如何选择特征和特征的切分点,具体原理请见 CART,算法,信息增益,熵等概念,这里实现的是几种树的生成方法。
- Src/tree/updater_colmaker.cc 贪婪搜索算法 (Exact Greedy Algorithm),最基本的树算法,一般都用它举例说明,这里提供了分布和非分布两种支持。在每个特征中选择该特征下的每个值作为其分裂点,计算增益损失。由内至外,关键函数分别是:EnumerateSplit() 穷举每一个枚举值,用 split_evaluator 打分。ParallelFindSplit() 多线程,其它同上 UpdateSolution() 调上面两个 split(),更新候选方案 FindSplit() 在当前层寻找最佳切分点,对比各个候选方案,方案来自上面的 UpdateSolution()
Src/tree/updater_histmaker.cc 它是 xgboost 默认的树生成算法,它和后面提到的 skmaker 都继承自 BaseMaker(BaseMaker 的父类是 TreeUpdate)是基于直方图选择特征切分点。HistMaker 提取 Local 和 Global 两种方式,Global 是学习每棵树前,提出候选切分点;Local 是每次分裂前,重新提出候选切分点。UpdateHistCol() 对每一个 col,做直方图分箱,返回一个分界 Entry 列表。
Src/tree/updater_skmaker.cc 继承自 BaseMaker(BaseMaker 父类 TreeUpdate)加权分位数草图,用子集替代全集,使用近似的 sketch 方法寻找最佳分裂点。
##5. 其它
1) GPU,多线程,分布式
代码中也有大量操作 GPU,多线程,分布式的操作,这里主要介绍核心流程,就没有提及,详见代码,其中.cu 和.cuh 是主要针对 GPU 的程序。
2) 关键字说明
CSR:csr_matrix 一种存储格式
Dmlc(Deep Machine Learning in Common):分布式深度机器学习开源项目
Rabit:可容错的 allrecude(分布式),支持 python 和 C++,可以运行在包括 MPI 和 Hadoop 等各种平台上面
Objective 与 Metric(Eval):这里的 Metric 和 Eval 都指评价函数,Objective 指损失函数,它们计算的都是实际值和预测值之间的差异,只是用途不同,Objective 主要在生成树时使用,用于计算误差和通过误差的方向调整树;而评价函数主要用于判断模型对数据的拟合程度,有时通过它判断何时停止迭代。
3) 基于直方图的切分点选择
分位数 quantiles:即把概率分布划分为连续的区间,每个区间的概率相同。把数值进行排序,然后根据你采用的几分位数把数据分为几份即可。
xgboost 用二阶导 h 对分位数进行加权,让相邻两个候选分裂点相差不超过某个值ε。因此,总共会得到 1/ε个切分点。 通过特征的分布,按照加权直方图算法确定一组候选分裂点,通过遍历所有的候选分裂点来找到最佳分裂点。它不会枚举所有的特征值,而是对特征值进行聚合统计,然后形成若干个 bucket(桶),只将 bucket 边界上的特征值作为 split point 的候选,从而获得性能提升,对稀疏数据效果好。
####6. 参考
1) XGBoost Documentation
https://xgboost.readthedocs.io/en/latest/
2) xgboost 入门与实战(原理篇)
https://blog.csdn.net/sb19931201/article/details/52557382
3) XGBoost 解析系列 -- 源码主流程
https://blog.csdn.net/matrix_zzl/article/details/78699605
4) XGBoost 论文翻译 + 个人注释
https://blog.csdn.net/qdbszsj/article/details/79615712
5) DART booster
https://blog.csdn.net/Yongchun_Zhu/article/details/78745529