TabPFN
1 简介
TabPFN(Tabular Prior-Data Fitted Network)是由 Meta AI 团队开发的针对表格数据的神经网络分类器。
1.1 主要特点是
- 无需超参数调优:TabPFN 与 XGBoost、LightGBM 等传统树模型不同,它即插即用,不需调整超参数。
- 极快的推理速度:TabPFN 训练和预测时间不到一秒,远超深度学习模型。
- 基于 Transformer:尽管规模小,它在预训练时利用大量合成数据,实现对新数据集的良好泛化能力。
- 适合小型数据集(<10K 样本):TabPFN 在小数据集上表现出色,但对大规模数据暂不是最佳。
1.2 适用场景
- 对于小型表格数据集,希望能够快速获得高质量的分类结果。
- 不希望投入太多时间在参数调优上,需要一个即插即用的分类器。
2 原理
TabPFN 通过元学习(Meta-Learning)预训练一个 Transformer,让其学习各种合成数据的模式,从而可以在新数据集上实现零训练和极速推理。对于小型表格数据分类任务,它是一个开箱即用且无需调参的强力工具。
3 阶段
- 预训练阶段(Offline Learning):
- 数据生成:在预训练时,TabPFN 没有使用真实的表格数据,而是通过神经过程(Neural Processes)和其他生成方法,创造了几百万个合成的小型分类任务(不同的特征、分布、噪声等)。
- 学习任务分布:模型通过学习这些任务的共同模式,相当于“见过”各种可能的数据集,从而掌握预测不同类型数据的能力。
- Transformer 作为先验模型(Transformer as Prior):
- 在训练过程中,使用小型 Transformer(大约 10M 参数)对这些合成数据的分类方式进行建模。
- 这个 Transformer 不是在单个数据集上训练,而是基于“数据集的分布”进行训练,类似于学习了各种数据分布,能够泛化到新任务。
- 推理阶段(Inference):
- 无需再训练:当拿到一个新的表格数据集时,TabPFN 不需要再经过训练,而是直接将数据输入到 Transformer,使其进行预测。
- 基于任务先验(Prior Knowledge):TabPFN 不是在新数据上优化参数,而是基于它在预训练阶段学到的先验信息进行直接分类。
- 极快的推理速度:由于 Transformer 的推理过程是一次性前向计算(forward pass),整个分类过程可以在毫秒级别完成。
4 用法
1 | pip install tabpfn |
1 | from tabpfn import TabPFNClassifier |
5 思考
这个模型拥有深刻的数据理解能力,能够根据具体数据找到最佳解决方案。在某些情况下,其表现可与 XGBoost 类模型相当甚至更好。有空尝试一下在金融领域的数据中应用。
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.