弹性权重巩固 EWC:基于 Fisher 信息的持续学习
什么是持续学习与灾难性遗忘
持续学习(Continual Learning),又称终身学习,是指模型在连续接收到不同任务的数据流时,能够不断学习新知识而不遗忘旧知识的能力。传统神经网络在依次训练多个任务时,会遭遇灾难性遗忘(Catastrophic Forgetting)——当用新任务的数据更新模型参数后,模型在旧任务上的表现会急剧下降。这是因为新任务的梯度更新会覆盖掉对旧任务重要的权重,导致旧知识被“擦除”。
为了解决灾难性遗忘,研究者提出了多种方法,主要包括:基于回放的方法、基于正则化的方法和参数隔离的方法。弹性权重巩固(Elastic Weight Consolidation,EWC)就是基于正则化的经典算法之一,它通过给重要参数的变化施加约束,让网络在学习新任务时“记住”旧任务。
EWC 的核心思想
EWC 的关键洞见是:并非所有权重对旧任务都同等重要。有些权重对旧任务的性能影响很大,有些则影响甚微。如果我们能找出那些“关键”权重,并在训练新任务时限制它们的剧烈变化,就可以保护旧知识。同时,不重要的权重仍然可以自由调整以适应新任务。
这一思想类似于物理中的弹簧系统:所有权重都被虚拟的弹簧连接回旧任务的解,弹簧的刚度与其对旧任务的重要性成正比。弹簧越硬,权重越难改变。数学上,这体现为一个在旧任务最优参数附近的二次惩罚项。
理论基础:Fisher 信息矩阵
EWC 用 Fisher 信息矩阵(Fisher Information Matrix)来度量权重的重要性。Fisher 信息衡量了每个参数对模型输出的敏感度。对于分类任务,通常计算输出概率的对数似然关于权重的二阶导数期望。
在旧任务训练完成后,我们得到最优参数 ( \theta_A^* )。Fisher 信息矩阵 ( F ) 的对角元素 ( F_i ) 近似为:
[ F_i = \mathbb{E}{(x,y)\sim D_A} \left[ \left( \frac{\partial \log p(y|x,\theta)}{\partial \theta_i} \right)^2 \right] \Bigg|{\theta = \theta_A^*} ]
该值越大,说明参数 ( \theta_i ) 的微小变化会对旧任务的输出产生较大影响,因此这个参数对旧任务越重要。为了计算方便,EWC 通常只使用 Fisher 对角元素,因为完整矩阵的计算和存储开销巨大。
EWC 损失函数
在新任务 B 上训练时,EWC 的总损失函数为:
[ \mathcal{L}(\theta) = \mathcal{L}B(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta{A,i}^*)^2 ]
- 第一项 ( \mathcal{L}_B(\theta) ) 是新任务 B 的任务损失(如交叉熵)。
- 第二项是 EWC 正则化项:它惩罚当前参数 ( \theta_i ) 与旧任务最优参数 ( \theta_{A,i}^* ) 之间的平方差,并用 Fisher 信息 ( F_i ) 加权。
- 超参数 ( \lambda ) 控制旧任务记忆强度与新任务学习能力之间的平衡。( \lambda ) 越大,模型越倾向于保留旧知识。
从优化角度看,这相当于在新任务损失曲面上叠加了一个各向异性的二次惩罚盆地,盆地中心是旧任务解,曲率由 Fisher 信息决定。那些 Fisher 信息大的方向曲率大,参数被“锚定”得更紧。
多任务 EWC 的扩展
当任务序列不止两个时,EWC 可以自然地扩展。一种朴素的做法是分别为每个旧任务保存一组最优参数和 Fisher 信息,然后把它们的惩罚项累加:
[ \mathcal{L}(\theta) = \mathcal{L}k(\theta) + \frac{\lambda}{2} \sum{j=1}^{k-1} \sum_i F_i^{(j)} (\theta_i - \theta_{j,i}^*)^2 ]
然而,这会导致惩罚项数量随任务数线性增长,存储和计算开销变大。在线 EWC(Online EWC)通过维护一个加权和式的单个 Fisher 矩阵和移动平均参数,将历史信息压缩成一组统计量,大大降低了复杂度,更适合流式任务场景。
EWC 的实现步骤(以 PyTorch 为例)
以下是一个简化的 EWC 训练流程,假设已完成旧任务 Task A 的训练。
1. 计算 Fisher 信息矩阵对角并保存旧参数
def compute_fisher_diag(model, dataloader):
fisher = {}
for name, param in model.named_parameters():
fisher[name] = torch.zeros_like(param)
model.eval()
for x, y in dataloader:
model.zero_grad()
output = model(x)
loss = F.nll_loss(F.log_softmax(output, dim=1), y)
loss.backward()
for name, param in model.named_parameters():
fisher[name] += param.grad.data ** 2 / len(dataloader)
return fisher
注意这里通过对数似然损失的梯度平方求平均来近似对角 Fisher 信息。保存 old_params 为 {name: param.clone()}。
2. 定义 EWC 正则化损失
def ewc_loss(model, old_params, fisher, lam=100):
loss = 0
for name, param in model.named_parameters():
loss += torch.sum(fisher[name] * (param - old_params[name]) ** 2)
return lam * loss / 2
3. 训练新任务时合并损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for x, y in new_task_loader:
output = model(x)
task_loss = F.nll_loss(F.log_softmax(output, dim=1), y)
reg_loss = ewc_loss(model, old_params, fisher, lam=100)
total_loss = task_loss + reg_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
参数说明
lam((\lambda)):通常需要通过验证集调参。太小无法防止遗忘,太大会导致新任务学习不充分。- Fisher 估计:可以采用更多数据样本来提高准确性,也可以定期更新(在线 EWC)。
- 批量归一化(BatchNorm)层的处理需要小心,因为其运行统计量也属于旧任务记忆的一部分,训练新任务时一般需冻结或单独管理。
EWC 的优点与局限
优点:
- 原理清晰,具有概率和信息几何的解释。
- 不会回放旧数据,避免隐私和存储问题。
- 实现相对简单,计算开销主要在于 Fisher 矩阵的估计和存储。
局限:
- 仅使用对角 Fisher 忽略了权重之间的协相关,约束球形假设可能不够精确。
- 在多任务序列中,累积的正则化项会让优化空间越来越平坦,新任务的学习能力会逐渐衰减。
- 需要确定任务边界(何时计算并固定 Fisher 和旧参数),不适合任务平滑过渡的场景。
- 超参数 (\lambda) 较敏感,且难以自适应调整。
与其他持续学习方法的比较
- 回放方法(如经验回放):直接存储旧数据或生成伪数据,能取得较好性能,但内存消耗大且存在隐私隐患。EWC 完全不存储数据。
- 突触智能(Synaptic Intelligence,SI):也使用权重的二次惩罚,但重要性度量是基于参数在训练过程中的累积贡献,不依赖 Fisher 信息,计算更轻量。
- 内存感知突触(Memory Aware Synapses,MAS):根据模型输出的变化来估计参数重要性,对无监督任务更友好。
- 参数隔离方法(如 Progressive Networks):为每个任务分配独立参数,完全避免遗忘,但模型规模随任务数线性增长。
EWC 以其优雅的统计基础和扎实的实验效果,成为持续学习领域的基石方法,也是理解和设计更复杂正则化方法的重要出发点。
总结
弹性权重巩固(EWC)通过在损失函数中引入 Fisher 信息加权的参数二次惩罚,有效缓解了神经网络的灾难性遗忘问题。它不需要存储旧数据,只需保存每个任务的最优参数和对应的 Fisher 对角矩阵。EWC 的训练流程清晰,易于实现,适合多任务顺序学习的场景。尽管存在对角近似和超参数敏感等局限,其核心思想启发了众多后续研究,是深入持续学习领域的必备知识。