通过强化学习实现图神经网络的类平衡与快速主动学习

时间:2026年3月23日
来源:Knowledge-Based Systems

编辑推荐:

提出GraphCBAL和GraphCBAL++,通过强化学习框架解决图神经网络主动学习中的类不平衡问题,设计兼顾信息量和类平衡的奖励函数,并扩展批量主动学习实现高效多节点采样。实验表明该方法在多个数据集上有效平衡分类性能与类分布。

广告
   X   

余成城|朱家鹏|李翔
上海工程技术大学计算机与信息工程学院,中国上海

摘要

图神经网络(GNNs)最近取得了显著的成功。针对GNNs的主动学习旨在从未标记的数据中查询有价值的样本进行注释,以低成本最大化GNNs的性能。然而,大多数现有的GNNs强化主动学习方法可能导致类别分布极度不平衡,尤其是在类别严重偏斜的情况下。这进一步影响了分类性能。为了解决这个问题,本文提出了一种新颖的强化类平衡主动学习框架——GraphCBAL。它学习了一种最优策略,以获取类平衡且信息丰富的节点进行注释,从而最大化使用选定标记节点训练的GNNs的性能。GraphCBAL设计了类平衡感知的状态以及一个在模型性能和类别平衡之间取得平衡的奖励函数。我们进一步通过引入惩罚机制将GraphCBAL升级为GraphCBAL++,以获得更加类平衡的标记集。为了提高GraphCBAL的效率,我们提出了BGraphCBAL,这是一种批量模式扩展,它在每次迭代中选择多个节点进行标记。我们将批量主动学习表述为一个合作的多智能体强化学习问题。我们还设计了一个多智能体策略网络,该网络不仅考虑了候选节点的信息丰富性和类别平衡性,还模拟了同一批次内选定的节点之间的相互作用。在多个数据集上的广泛实验证明了所提出方法的有效性和效率,其性能优于现有的最佳基线方法。特别是,我们的方法能够在分类结果和类别平衡之间取得平衡。我们的代码和数据可以在https://github.com/cici-chengcheng/GraphCBAL获取。

引言

图神经网络(GNNs)由于在各种下游任务中的成功而受到了广泛关注,例如节点分类[1]、[2]、链接预测[3]和图分类[4]。例如,早期模型GCN[1]使用简化的一阶近似在谱空间中聚合节点特征,而GraphSage[2]直接在空间域中聚合来自节点邻居的特征。尽管性能吸引人,但大多数GNN模型通常需要大量的标记数据进行训练,而这些数据的获取通常成本很高。
针对这一挑战,GNNs的主动学习作为一种有前景的策略应运而生,其目标是在有限的预算内动态查询未标记数据中最有价值的样本进行注释,以最大化GNNs的性能。已经提出了一些有效的图主动学习方法[5]、[6]、[7]、[8]。然而,这些方法选择的样本可能表现出高度不平衡的类别分布,特别是在类别严重偏斜的情况下。例如,如图1(a)中的粉线所示,基准数据集Coauthor_Physis(Co-Phy)[9]是一个将作者分类为五个类别的合著网络,其类别分布高度不平衡,按照相应类别中的样本数量递减排序。具体来说,第一个类别有大量的样本(也称为头部类别),而其余类别只有少量样本(也称为尾部类别)。我们使用了两种最先进(SOTA)的主动学习方法GPA[8]、ALLIE[10]和GreedyET[11](详细信息将在第2节中介绍),以及我们提出的方法GraphCBAL++,在数据集上获取相同数量的节点进行注释。从图1(a)可以看出,GPA和ALLIE获得的标记节点的类别分布都高度不平衡,大多数节点来自头部类别,而GraphCBAL++获得了平衡的结果。为了进一步展示所选节点的有效性,我们使用这些标记节点训练GCN并进行了节点分类。结果如图1(b)所示,GPA和ALLIE在标记节点最少的尾部类别4上获得了较差的结果。相比之下,我们的方法在不同类别上表现出了稳健的性能。
类别不平衡的标记训练数据可能导致分类结果中的类别不平衡问题[12]、[13]。由于典型的GNNs在设计时没有考虑这个问题,使用类别不平衡的标记数据训练GNNs可能会引入对多数类别的预测偏差,从而导致整体性能下降。同时,在现实世界应用中,类别不平衡现象很普遍,例如在欺诈检测[14]中,交易网络中的欺诈者数量远少于良性实体;在引文网络的主题分类[15]和社交网络的机器人检测[16]中也是如此。因此,对于主动学习来说,选择类平衡且有价值的样本进行注释以避免类别不平衡问题至关重要。尽管在计算机视觉[17]、[18]、[19]、[20]、[21]中已经对不平衡类别的主动学习进行了充分研究,但大多数现有方法基于独立同分布(i.i.d.)假设,直接将其应用于图可能会不合适或无效。这进一步促使我们研究类平衡的图主动学习,该方法在选择“信息丰富”和“类平衡”的节点进行标记时应该考虑图结构。
在本文中,我们提出了一种图类的平衡主动学习方法——GraphCBAL。它采用强化学习(RL)框架,学习一种最优策略来查询类平衡且信息丰富的样本进行注释,从而最大化使用选定标记节点训练的GNNs的性能。具体来说,我们将类平衡主动学习表述为一个马尔可夫决策过程(MDP),并学习最优查询策略。状态基于当前图的状态定义,该状态考虑了节点的信息丰富性和类别平衡;动作是在每次查询步骤中选择一个节点进行标记;奖励函数考虑了使用选定节点训练的GNNs的性能提升以及预定义的类别多样性得分,这可以提高性能和类别平衡。为了获得更加类平衡的标记集,我们通过引入惩罚机制将GraphCBAL升级为GraphCBAL++,在奖励函数中添加了一个惩罚项,强制选择少数类别的节点。为了更稳定和有效的训练,我们使用了优势演员-评论家(A2C)算法[22]来学习查询策略,其中演员网络和评论家网络由两个GCN组成,同时考虑了信息丰富性和节点之间的相互依赖性。
此外,大多数现有的主动学习方法[5]、[6]、[7]、[8]以及我们提出的GraphCBAL都是逐个选择样本的,这使它们在处理大规模数据集时效率低下。为了进一步提高主动学习方法的效率,批量主动学习是一种更有前景的方法,它在每次迭代中为多个样本获取标签。然而,批量主动学习面临以下挑战:首先,批量模式策略必须在每次迭代中选择多个“信息丰富”和“类平衡”的节点;其次,每个动作对应于选择多个节点的组合,导致动作空间中的组合爆炸。尽管BIGENE[23]引入了一个多智能体强化学习框架来查询每个步骤中的多个样本,但它仍然忽略了由于类别不平衡导致的性能下降问题。
为了解决上述挑战,我们提出了一种批量模式方法BGraphCBAL,以提高GraphCBAL的效率。它学习了一个多智能体策略,在每次迭代中选择多个节点。该策略包括两个组件:一个GCN模块确定第一个标记节点,以及一个GRU模块通过模拟同一批次内选定的节点之间的相互作用来选择其余节点。奖励函数由使用选定批次训练获得的性能提升以及选定节点的聚合类别多样性组成。我们采用集中训练与分散执行(CDTE)范式[24]来训练策略,其中集中式评论家协助训练分散式策略智能体以实现全局合作,所有智能体在每个时间步骤共享相同的全球奖励。
最后,我们在本文中的主要贡献总结如下:
  • 我们提出了GraphCBAL,这是一种针对图的类平衡主动学习方法。据我们所知,我们是第一个将类平衡引入图强化学习的方法。
  • 我们设计了一个有效的奖励函数,可以在分类性能和类别平衡之间取得平衡。我们还引入了一个类平衡感知的状态空间,用于采样信息丰富的节点。
  • 为了提高GraphCBAL的效率,我们提出了一种批量模式方法BGraphCBAL,它在每次迭代中选择多个节点进行标记。我们设计了一个多智能体策略网络,该网络不仅考虑了候选节点的信息丰富性和类别平衡性,还考虑了同一批次内选定的节点之间的相互作用。
  • 我们在七个基准数据集上进行了广泛的实验,证明了我们方法的有效性和效率。它始终产生更加类平衡的标记集,与现有最佳方法相比,分类性能相当或更优,同时显著提高了效率。
  • 章节片段

    图上的主动学习

    主动学习已在计算机视觉和自然语言处理等多个领域得到广泛研究。最近有几项工作专注于图结构数据的主动学习[5]、[6]、[7]、[11]、[25]、[26]、[27]、[28]。例如,AGE[5]通过三个标准的线性组合来衡量节点的信息丰富性,包括信息熵、密度和中心性。它从未标记的节点中找到最信息丰富的节点。ANRMAB[6]扩展了AGE,使用了相同的标准,并

    问题定义

    G=(V,E)表示一个图,其中V是节点集,E是边集。ARN×N是邻接矩阵,XRN×d是节点特征矩阵。每个节点vV都有一个标签c(v)∈{1, ..., m),m是节点类别的数量。节点集被分为三个子集,包括Vtrain、Vvalid和Vtest。在传统的半监督节点分类中,子集LVtrain的标签是给定的。任务是使用图G和标签L学习一个分类f来预测节点

    方法论

    在本节中,我们描述了我们提出的模型。我们将图上的类平衡主动学习问题表述为一个马尔可夫决策过程(MDP)。具体来说,状态基于当前图G的状态和训练好的GNN分类f来定义。我们将步骤t的状态表示为St。由θ参数化的主动学习策略πθ通过选择下一个节点进行查询来采取行动。为了提高分类性能和标记的类别平衡

    BGraphCBAL:图的批量类平衡主动学习

    GraphCBAL在每次主动学习(AL)迭代中选择一个节点进行注释,这需要训练GNN分类器以生成更新的状态。然而,GNN训练往往成为逐个主动学习过程中的主要计算瓶颈,由于重复训练而在大规模数据集上效率低下。为了说明这一点,我们报告了GraphCBAL在所有数据集中每次主动学习迭代中每个组件的运行时间。

    数据集

    我们使用了七个广泛使用的基准数据集,包括Cora、Citeseer、Pubmed、Reddit、Coauthor-CS(Co-CS)、Coauthor-Physics(Co-Phy)[9]和ogbn-arxiv(ArXiv)[61]。前三个数据集是引文图,节点代表文档,边代表引用。Reddit是一个在线论坛数据集,其中节点代表帖子,如果至少有两个用户对帖子进行评论或发布,则这两个帖子通过边连接。Co-CS和Co-Phy是连接作者节点的合著网络

    结论

    在本文中,我们提出了GraphCBAL,这是一种新颖的强化和类平衡的GNNs主动学习方法。它学习了一种最优策略来获取类平衡且信息丰富的节点进行注释,从而最大化在选定标记节点上训练的GNNs的性能。引入了一个类平衡感知的状态空间,奖励函数基于性能提升和类别多样性设计,旨在在模型性能和类别平衡之间取得平衡。同时,我们

    CRediT作者贡献声明

    余成城:写作——审阅与编辑,撰写——原始草稿,可视化,软件,方法论,资金获取,概念化。朱家鹏:验证,软件,调查。李翔:监督,资金获取,撰写——原始草稿,写作——审阅与编辑。

    利益冲突声明

    作者声明他们没有已知的竞争财务利益或个人关系可能影响本文报告的工作。

    生物通微信公众号
    微信
    新浪微博


    生物通 版权所有