在当今数据驱动的时代,表格数据——即以行和列形式组织的数据——无疑是现实世界机器学习任务中最常见的数据类型之一。从预测用户点击行为、评估房价,到预估医疗事件风险,其应用场景无处不在。长久以来,处理这类数据的“王者”一直是梯度提升决策树这类基于树状结构的模型。然而,近年来深度学习在计算机视觉和自然语言处理等领域的巨大成功,激发了人们探索深度学习模型处理表格数据的浓厚兴趣。
尽管热情高涨,但将深度学习模型应用于表格数据并非易事。一个核心难题在于,表格数据具有天生的“异构性”和“排列不变性”。与图像或文本不同,表格中的每一列特征(例如年龄、收入、职业类别)可能类型不同(数值型或分类型),取值范围和含义也各异。更重要的是,交换表格中列的顺序,通常不会改变数据本身的含义。这种特性使得那些严重依赖数据空间结构或序列位置的传统深度学习架构难以直接适用。此外,表格数据中特征之间的交互关系可能非常复杂(例如,教育背景、职业与收入之间的相互影响),如何有效捕捉和建模这些交互,成为了提升模型性能的关键瓶颈。
为了解决上述问题,并探索图神经网络在表格数据领域的潜力,来自新西兰惠灵顿维多利亚大学数学与统计学院的Pimwipa Charuthamrong、Colin R. Simpson和Binh P. Nguyen合作开展了一项研究。他们的目标是将表格数据转化为图结构,并设计一种新的图神经网络架构来学习特征间的复杂交互。这项研究成果最终发表在人工智能领域的知名期刊《Neural Networks》上。
为了开展这项研究,作者团队主要采用了以下几种关键技术方法:首先,他们为每个数据样本(表格中的一行)构建了一个全连接、无权的特征交互图,其中节点代表经过上下文编码的特征(对数值特征使用线性编码,对分类特征使用嵌入查找编码),边代表特征间的交互。其次,他们设计了一个基于图同构网络的消息传递图神经网络架构。该架构的核心创新在于使用一个多层感知机动态学习并更新多维边属性,同时借鉴GIN的卷积方式更新节点表示。最后,为了应对小规模特征图上的过平滑问题,他们在节点和边的更新方程中引入了残差连接机制。模型在包含分类和回归任务的12个公开数据集上进行了训练与评估。
3.1. 图构建
研究人员为每个数据样本(即表格中的一行)构建了一个全连接的无权图。图中的每个节点代表一个特征,该节点的初始表示是通过对数值特征和分类特征进行专门的上下文编码得到的。每条边则代表两个特征之间的潜在交互关系,其属性(即边的特征向量)在初始时并未设定,而是在后续的图神经网络训练过程中动态学习得到的。此外,图中还添加了一个特殊的“分类节点”,该节点与所有特征节点相连,旨在聚合全图信息,用于最终的预测任务。这种图构建方式直观地将特征间的复杂关系建模为图结构,为后续的图学习奠定了基础。该过程可通过Figure 1 形象展示。
3.2. 模型架构
模型架构的核心是一个受图同构网络启发的消息传递层。其关键创新在于同时学习并更新节点和边的表示。具体而言,对于每一条连接节点i和节点j的边,其更新的边属性e‘是通过一个多层感知机(MLP)对源节点、目标节点以及当前边属性进行非线性变换得到的。随后,在更新节点i的表示时,不仅会聚合其所有邻居节点j的表示,还会将对应边更新后的属性e’加到邻居节点的信息中,再通过求和聚合,最后通过另一个MLP得到节点i的新表示。这种设计使得边属性能够动态地捕捉和传递特征间交互的具体模式。更重要的是,为了缓解在小规模特征图上堆叠多层GNN时常见的过平滑问题,该模型在节点更新中引入了残差连接,将节点初始表示和上一层表示纳入计算;同样,边属性也通过残差连接进行更新。这种设计有助于保留不同层的特征信息,从而构建出更深、更强大的网络。Figure 2 清晰地展示了这一层的计算流程。
4. 实验
为了全面评估所提出模型(名为Edge-updating GNN)的性能,研究团队在12个公开可用的表格数据集上进行了广泛的实验,涵盖了二分类、回归和多分类任务。他们将模型与六种先进的表格深度学习模型(包括TabTransformer、FT-Transformer、ResNet、TabNet、ExcelFormer和另一个GNN模型INCE)进行了比较。此外,由于梯度提升决策树仍是表格数据领域的标杆,研究还比较了XGBoost和CatBoost。
4.4. 结果与讨论
实验结果表明,该Edge-updating GNN模型在综合表现上最佳。在采用调优超参数的情况下,该模型在12个数据集中取得了最佳的平均排名。具体而言,它在Adult、Credit、House、Bike和Med等多个数据集上取得了最优或并列最优的性能。尤为值得一提的是,即使在使用默认超参数时,该模型在所有数据集上的表现也均优于XGBoost和CatBoost。在调优后,它仍在超过半数的数据集中优于这两个强大的GBDT模型。
为了探究模型架构中各个组件的贡献,研究进行了消融实验。结果显示,同时包含边属性更新和残差连接的完整模型取得了最好的平均排名,而移除其中任何一个组件都会导致性能的轻微下降。这表明,边更新机制和残差连接对于在小规模特征图上有效学习特征交互、防止过平滑具有协同促进作用。
研究还将提出的卷积层替换为其他经典的GNN卷积层(如GAT、GCN、GIN、GraphSAGE以及较新的ADP-GNN)进行对比。在相同的图构建和解码器设置下,Edge-updating GNN在绝大多数数据集上表现最优,特别是在回归和多分类任务上优势明显。这证明了其设计的卷积方式能够更有效地从特征交互图中学习到高质量的表示。
最后,研究还探索了不同节点/边特征维度以及网络层数对性能的影响。总体趋势是更高的维度(如512维)能带来微弱的性能提升,但会显著增加训练时间,尤其是在大型数据集上。而就层数而言,使用2层的模型取得了最佳的综合性能,增加层数并未带来一致的提升,这进一步印证了在浅层网络上通过精细的边和节点更新机制来有效建模交互的重要性。
综上所述,这项研究提出了一种新颖的用于表格数据的图神经网络架构。通过构建特征交互图,并设计一种能够同时、动态地学习节点和边表示的消息传递机制,该模型成功地捕捉了表格数据中复杂的特征间关系。引入的残差连接有效缓解了过平滑问题,使得模型在相对浅层的网络上也能取得优异性能。广泛的实验验证了该模型在多种任务上相对于当前先进深度学习和传统梯度提升决策树模型的竞争力。这项工作不仅为表格数据的深度学习提供了一个强有力的新工具,也为理解和建模特征交互开辟了新的图表示学习视角。其代码已公开,为后续研究和应用提供了便利。
打赏