GraphSAGE:归纳式学习节点嵌入的聚合函数
GraphSAGE 归纳式图学习教程
图数据在社交网络、推荐系统、知识图谱等领域无处不在。传统的图嵌入方法(如 DeepWalk、Node2Vec)和早期的图神经网络(如 GCN)大多是直推式的——它们只能为训练阶段出现过的节点生成嵌入,一旦图结构或节点集合发生变化就需要重新训练。GraphSAGE(SAmple and aggreGatE)正是为突破这一限制而生的归纳式图学习框架。本教程将从零开始带你理解 GraphSAGE 的核心原理、聚合函数、训练方式,并通过伪代码帮助你快速落地。
1. 什么是归纳式图学习?
在进入 GraphSAGE 之前,我们先厘清一个关键概念。
| 类型 | 特点 | 典型方法 | 限制 |
|---|---|---|---|
| 直推式 (Transductive) | 训练期间必须看到所有节点的结构信息,学出的嵌入表是固定大小的。 | DeepWalk, Node2Vec, GCN (原始) | 新节点加入时必须重训练,无法泛化到未见过的图。 |
| 归纳式 (Inductive) | 学习一个函数,该函数利用节点特征和邻居结构来生成嵌入。新节点到来时,只需调用函数即可,无需重新训练。 | GraphSAGE, GAT, GraphSAINT | 模型需要能处理未知度的节点。 |
GraphSAGE 正是归纳式方法的奠基之作。它不直接为每个节点维护一个嵌入向量,而是训练一组聚合函数,通过聚合邻居特征来生成节点表示。这使得模型可以平滑地扩展到数十亿节点的动态图。
2. GraphSAGE 核心思想
GraphSAGE 的全称是 SAmple and aggreGatE,其名字揭示了两个关键步骤:
- 采样 (Sample):对每个目标节点,均匀随机采样固定数量的邻居(而不是使用全部邻居),以保证计算图大小可控。
- 聚合 (Aggregate):定义一个可微的聚合函数,将采样邻居的表示聚合为一个向量,并与目标节点自身的特征拼接,经过非线性变换生成该节点的嵌入。
因为学习的是生成嵌入的函数而非查表式嵌入,GraphSAGE 天然具备归纳能力——新节点只要有特征和连接关系,就能立即生成嵌入。
3. 算法前向传播详解
假设图中有节点特征 $\mathbf{x}_v \in \mathbb{R}^F$(如文本词袋、用户属性),模型参数为 $K$ 个聚合器($K$ 层)。第 $k$ 层的计算过程如下:
3.1 邻居采样
对于目标节点 $v$,从其邻居集合 $\mathcal{N}(v)$ 中均匀无放回采样 $S$ 个节点,记为 $\mathcal{N}_S(v)$。如果邻居数少于 $S$,则使用有放回采样补足。
这种固定大小的采样使得每个节点在批量训练中花费的计算和存储一致,且能避免“邻居爆炸”问题(高度数节点的邻居过多)。
3.2 聚合与更新
第 $k$ 层($k \in {1, ..., K}$)的传播公式为:
$$\mathbf{h}^{k}_{\mathcal{N}(v)} = \text{AGGREGATE}_k\big({\mathbf{h}^{k-1}_u, \forall u \in \mathcal{N}_S(v)}\big)$$
$$\mathbf{h}^{k}_v = \sigma\Big( \mathbf{W}^k \cdot \text{CONCAT}\big(\mathbf{h}^{k-1}v, \mathbf{h}^{k}{\mathcal{N}(v)}\big) \Big)$$
最后对节点嵌入进行 L2 归一化(可选):
$$\mathbf{h}^{k}_v \leftarrow \frac{\mathbf{h}^{k}_v}{|\mathbf{h}^{k}_v|_2}$$
初始输入为节点原始特征:$\mathbf{h}^0_v = \mathbf{x}_v$。经过 $K$ 层迭代后,$\mathbf{h}^K_v$ 即为最终嵌入,它融合了 $K$ 跳邻居的信息。
要点:
- 每一层的聚合器 $\text{AGGREGATE}_k$ 和权重矩阵 $\mathbf{W}^k$ 是共享于所有节点的,这是归纳能力的关键。
- CONCAT 操作保留了节点自身与邻居信息,避免自身特征被过度稀释。
4. 聚合函数的选择
聚合函数必须满足两个条件:① 可处理变长的无序邻居集合;② 可微,方便反向传播。GraphSAGE 提出了三种经典聚合器:
4.1 均值聚合(Mean Aggregator)
最基础的聚合方式,对邻居向量逐元素求平均(不包含自身,随后拼接):
$$\mathbf{h}^{k}_{\mathcal{N}(v)} = \text{mean}\big({\mathbf{h}^{k-1}_u, \forall u \in \mathcal{N}_S(v)}\big)$$
均值聚合与 GCN 的传播规则非常相似,但它通过拼接自身向量(而非自环)来保留身份信息。
4.2 LSTM 聚合
利用 LSTM 处理邻居序列。由于 LSTM 天然对输入顺序敏感,而邻居是无序的,GraphSAGE 对邻居进行随机打乱后再输入 LSTM。这在实际中表现优异,尤其在结构复杂的图上。
$$\mathbf{h}^{k}_{\mathcal{N}(v)} = \text{LSTM}\big({\text{Randomly permuted } \mathbf{h}^{k-1}_u}\big)$$
4.3 池化聚合(Pooling Aggregator)
先对每个邻居向量通过一个可学习的全连接层,然后逐元素取最大池化(或均值池化),以捕获邻居中最显著的特征:
$$\mathbf{h}^{k}{\mathcal{N}(v)} = \max\big({\sigma(\mathbf{W}{\text{pool}} \mathbf{h}^{k-1}_u + \mathbf{b}), \forall u \in \mathcal{N}_S(v)}\big)$$
其中 $\sigma$ 可以是 ReLU,$\max$ 是逐元素取最大值。这种聚合器理论上可以学习到任意对称函数。
实际建议:在大部分任务中,均值聚合实现简单且稳定;当需要捕获更复杂的邻居模式时,最大池化表现更好。
5. 模型训练
GraphSAGE 支持两种典型的训练范式。
5.1 无监督训练
通过随机游走或负采样,使邻近节点的嵌入相似,远离节点的嵌入差异大。损失函数为:
$$\mathcal{L}(z_v) = -\log\big(\sigma(z_v^\top z_u)\big) - Q \cdot \mathbb{E}{v_n \sim P_n(v)} \log\big(\sigma(-z_v^\top z{v_n})\big)$$
其中 $u$ 是与 $v$ 在固定长度随机游走中共现的邻居,$v_n$ 是从负采样分布 $P_n(v)$ 中采样的负样本,$Q$ 是负样本数。此损失鼓励相邻节点嵌入的余弦相似度高。
5.2 监督训练
直接将节点嵌入输入到下游分类或回归任务中,例如使用交叉熵损失进行节点分类。
无论哪种范式,训练时都采用小批量:每个批次随机采样若干目标节点,然后为这些节点构建 $K$ 层计算图(通过迭代采样邻居),最后在计算图上运行前向传播和反向传播。这种每个批次独立构图的机制完美适配大规模图。
6. 归纳能力的威力
GraphSAGE 的归纳优势体现在两大场景:
- 动态/增长图:社交网络中新注册的用户,推荐系统中新上架的商品,只需用该节点特征和其连接关系,立即生成嵌入,无需重新训练模型。
- 跨图泛化:在蛋白质相互作用图中,可以在一个物种的图中训练 GraphSAGE,然后直接应用到另一个物种的图(只要节点特征维度一致),这是直推式方法完全无法做到的。
实验表明,即使新节点完全未在训练图中出现,GraphSAGE 的准确率下降幅度也很小。
7. 伪代码与实现要点
以下伪代码描述 GraphSAGE 单次前向传播的关键步骤:
输入: 目标节点集合 B; 采样邻居数 S; 层数 K; 特征矩阵 X; 邻接列表
输出: 节点嵌入 Z
for k = 1...K:
构建本层计算图:
for 每个节点 v in B:
N_S(v) = 均匀采样 S 个邻居(如果邻居不足则重复采样)
将 N_S(v) 加入本层需要的节点集合 B^k
将特征矩阵 X 传播到上一层嵌入 H^{k-1}
for 每个节点 v in B:
h_N_v = AGGREGATE_k({H^{k-1}[u] for u in N_S(v)})
H^k[v] = σ( W^k * CONCAT(H^{k-1}[v], h_N_v) )
对 H^k 可选地 L2 归一化
B = B^k (进入下一层时,当前目标节点变为邻居的邻居)
Z = H^K
在 PyTorch 或 TensorFlow 中,通常使用稀疏操作和 scatter 函数实现高效的聚合。每一层聚合器可作为单独的 Module,并共享权值。
8. GraphSAGE 的优缺点
| 优点 | 缺点 |
|---|---|
| 归纳式学习,适用于动态图和新节点。 | 采样过程引入随机性,需仔细设置采样数 S 和层数 K。 |
| 小批量训练,内存友好,可扩展至超大规模图。 | 均匀随机采样可能丢失重要的高影响力邻居。 |
| 灵活的聚合函数,可根据任务选择。 | LSTM 的顺序敏感性可能不够稳定。 |
| 对特征丰富的数据效果突出。 | 纯结构图(无节点特征)时,需要人工构造特征。 |
9. 应用建议与总结
当你在实际项目中决定是否使用 GraphSAGE 时,可以遵循以下指南:
- 节点特征充足(如文本、属性、图像特征)→ GraphSAGE 比纯拓扑方法更好。
- 图不断变化,需要频繁对新节点做预测 → 归纳式方法首选 GraphSAGE。
- 图极大,无法全图一次性加载 → 小批量采样的 GraphSAGE 是定心丸。
- 需要可解释性 → Mean Pooling 的邻居贡献更直观。
GraphSAGE 将图学习从“记住每个节点”推进到“学会如何聚合邻居”的范式,奠定了后续大量图神经网络(如 GAT、GraphSAGE 变体)的基础。掌握它的核心机制——采样、聚合、归纳泛化——将使你在图数据实战中游刃有余。
继续学习:尝试自己实现一个拥有均值聚合与最大池化聚合的 GraphSAGE 层,并在 Cora 或 Reddit 数据集上比较直推式 GCN 与归纳式 GraphSAGE 对新节点的分类性能。你会发现,归纳能力是走向真实图系统的关键一步。