Jun, 2024

基于FT-TABPFN模型的表格分类的标记化特征增强

TL;DR传统的表格分类方法通常依赖于从头开始的有监督学习,需要大量训练数据来确定模型参数。然而,一种名为Prior-Data Fitted Networks(TabPFN)的新方法改变了这一范式。TabPFN使用在大型合成数据集上进行训练的12层变压器来学习通用的表格表示。这种方法能够通过单次前向传递快速和准确地对新任务进行预测,且无需额外的训练。虽然TabPFN在小型数据集上表现出色,但处理分类特征时通常表现较弱。为了克服这一限制,我们提出了FT-TabPFN,它是TabPFN的增强版本,包括了一种新颖的特征标记化层来更好地处理分类特征。通过针对下游任务进行微调,FT-TabPFN不仅扩展了原始模型的功能,而且在表格分类中显著提高了其适用性和准确性。我们的完整源代码可供社区使用和开发。