对比学习 MoCo:动量编码器与动态字典
对比学习 MoCo:动量编码器与动态字典
在自监督学习中,对比学习通过拉近相似样本(正样本对)的表示、推开不相似样本(负样本对)的表示,学习强大的特征提取器。MoCo(Momentum Contrast)由何恺明等人提出,核心在于将字典看作队列并用动量更新的编码器维护这个字典,从而在有限显存下获得海量且一致的负样本,极大提升了对比学习的性能。
为什么需要大规模的负样本?
对比学习的损失函数(如 InfoNCE)可形式化为:
[ \mathcal{L}q = -\log \frac{\exp(q \cdot k+ / \tau)}{\exp(q \cdot k_+ / \tau) + \sum_{i=1}^{K} \exp(q \cdot k_i / \tau)} ]
其中 ( q ) 是锚点(查询)的表示,( k_+ ) 是其正样本的表示,( k_i ) 是负样本的表示,( \tau ) 为温度系数。想要学到有区分力的特征,负样本的数量 ( K ) 通常越大越好。
传统方法受限于两个因素:
- 端到端方式(如 SimCLR):使用同一个 batch 中的其他样本作为负样本,( K ) 受 batch size 限制,而大 batch size 需要极高的显存。
- 记忆库(Memory Bank):离线存储所有样本的特征,但更新不及时,导致特征与当前模型不一致,训练不稳定。
MoCo 巧妙地将负样本抽象为一个动态字典,并分离出两个关键设计:队列字典和动量编码器。
MoCo 的核心设计
1. 队列字典 —— 解耦字典大小与 batch size
MoCo 维护一个固定大小的队列(默认长度 ( K = 65536 )),存放由动量编码器产生的负样本特征。
- 每个训练 step,当前 mini-batch 的正样本特征会被入队。
- 同时,队列中最老的一批特征会被出队。
- 所有负样本特征直接从队列中读取,不再依赖于当前 batch。
这样做的好处:字典大小与 batch size 完全解耦。即使使用很小的 batch size,也能拥有成千上万的负样本,且字典中的样本来自最近的多个 batch,保证了多样性。
队列字典更新示意:
Step t: 队列 [f0, f1, f2, ..., f65535] → 取出最老的 f0,压入当前正样本特征 f_new
Step t+1: 队列 [f1, f2, ..., f65535, f_new]
2. 动量编码器 —— 维护特征一致性
如果直接用查询编码器来更新队列中的特征,会导致编码器快速变化,队列里早先的特征与当前编码器的输出不再对齐,损害训练稳定性。MoCo 引入一个动量更新的 key 编码器:
- 查询编码器 ( f_q ):正常反向传播更新。
- key 编码器 ( f_k ):参数 ( \theta_k ) 由查询编码器参数 ( \theta_q ) 通过指数移动平均(EMA)更新: [ \theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q ] 其中动量系数 ( m ) 通常设为接近 1 的值,如 0.999。
这样,key 编码器更新非常缓慢,保持了特征的平滑性和一致性,使得队列中虽然存放着不同时间步产生的特征,但它们都来自一个近似的编码器,比直接复制参数或快速变化更可靠。
3. 整体训练流程
- 从数据增强模块生成同一图像的两个不同视图 ( x^q ) 和 ( x^k )。
- ( x^q ) 送入查询编码器 ( f_q ),得到查询向量 ( q )。
- ( x^k ) 送入动量编码器 ( f_k ),得到 key 向量 ( k_+ )(正样本)。
- 将 ( k_+ ) 与队列中保存的负样本 key 向量拼接,作为字典 ( {k_0, k_1, ..., k_K} ),其中 ( k_0 ) 就是 ( k_+ )。
- 计算 InfoNCE 损失,反向传播只更新 ( f_q ) 的参数。
- 通过动量公式更新 ( f_k ) 的参数。
- 将当前 batch 的 key 特征 ( k_+ ) 入队,弹出最老的 key 特征。
# 伪代码:MoCo 核心逻辑
for x in loader:
x_q = aug(x) # 查询视图
x_k = aug(x) # key 视图
q = f_q(x_q) # 查询编码器,梯度更新
with torch.no_grad():
k = f_k(x_k) # 动量编码器,停止梯度
# 正样本为 k,负样本从队列中获取
l_pos = (q * k).sum(dim=1, keepdim=True) # 正样本分数
l_neg = (q @ queue.t()) # 负样本分数
logits = torch.cat([l_pos, l_neg], dim=1)
loss = CrossEntropyLoss(logits / tau, labels)
loss.backward()
optimizer.step() # 仅更新 f_q
# 动量更新 f_k
f_k.params = m * f_k.params + (1 - m) * f_q.params
# 入队新 key,出队最老 key
queue = torch.cat([queue, k], dim=0)[batch_size:]
MoCo 为什么有效?
- 大且一致的字典:队列提供数千个负样本,且动量缓慢更新的编码器保证这些特征来自几乎相同的“老师”,避免了记忆库的不一致性。
- 超高效率:无需超大 batch size,单张 GPU 甚至 CPU 也能训练出有竞争力的无监督表示。
- 即插即用:学习到的特征可直接迁移到分类、检测、分割等下游任务,只需简单微调。
后续改进如 MoCo v2 引入了更强的数据增强和投影头,MoCo v3 则解决了动量编码器在 ViT 训练中的稳定性问题,但其灵魂始终是队列字典 + 动量编码器这两个基石。
总结
MoCo 把对比学习中的负样本字典视为一个动态更新的队列,并采用动量缓慢更新的编码器来维护字典中的特征一致性。这一设计突破了批大小的限制,让无监督特征学习能够高效利用大量负样本,成为自监督视觉表征学习史上的里程碑。