近似遗忘:用影响函数和梯度修正模拟数据删除
近似遗忘:用影响函数与梯度修正模拟数据删除
1. 什么是机器遗忘
在机器学习中,机器遗忘(Machine Unlearning) 指从已训练模型中彻底移除特定训练样本的影响。当用户要求删除个人数据时,我们不仅需要从数据库里清除该记录,还应使模型表现得“仿佛从未见过这条数据”。
现实中,直接重新训练模型代价极高,因此需要近似遗忘技术,在不完全重训的前提下,快速更新模型,模拟数据删除的效果。
2. 精确遗忘的挑战
- 完整重训成本过高:现代大模型训练可能需要数周、数百万美元。每次数据删除请求都重新训练完全不现实。
- 增量学习不是遗忘:简单继续训练或微调并不能保证旧数据的影响被消除,模型仍可能保留记忆。
- 法律与合规要求:GDPR、CCPA 等法规赋予用户“被遗忘权”,强制要求在一定时间内完成数据影响的消除。
因此,我们需要一种计算高效、可证明的近似遗忘方法。
3. 影响函数:衡量单条数据的影响力
影响函数(Influence Functions)来自稳健统计学,用于量化改变一个训练样本的权重后,模型参数或预测会如何变化。
3.1 基本思想
假设模型参数 θ 由经验风险最小化得到:
$$ \hat{\theta} = \arg\min_{\theta} \frac{1}{n} \sum_{i=1}^{n} L(z_i, \theta) $$
当我们移除样本 $z_{\text{del}}$ 时,新参数 $\hat{\theta}_{-z}$ 满足:
$$ \hat{\theta}{-z} = \arg\min{\theta} \frac{1}{n-1} \sum_{z_i \neq z_{\text{del}}} L(z_i, \theta) $$
影响函数给出 $\hat{\theta}_{-z}$ 对 $\hat{\theta}$ 的一阶近似,无需重新优化。
3.2 参数影响函数
Koh & Liang (2017) 给出:去掉一个训练样本 $z$ 对参数的影响为
$$ \mathcal{I}{\text{param}}(z) = -H{\hat{\theta}}^{-1} \nabla_{\theta} L(z, \hat{\theta}) $$
其中 $H_{\hat{\theta}}$ 是所有训练数据的 Hessian 矩阵(或有限记忆近似)。
- 解释:这条公式告诉我们,移除 $z$ 后参数应该向哪个方向修正,修正量与损失梯度成正比,并由曲率矩阵 $H^{-1}$ 进行缩放。
3.3 对预测的影响
对于测试点 $z_{\text{test}}$,删除样本 $z$ 对其损失或预测的改变量为:
$$ \mathcal{I}{\text{up,loss}}(z, z{\text{test}}) = -\nabla_{\theta} L(z_{\text{test}}, \hat{\theta})^\top H_{\hat{\theta}}^{-1} \nabla_{\theta} L(z, \hat{\theta}) $$
这就是我们衡量单条数据重要性的核心工具。
4. 利用影响函数实现近似遗忘
有了影响函数,就可以在收到删除请求时,直接修正模型参数,而不必重新训练。
4.1 参数修正公式
假设需要删除一组样本 $D_{\text{del}}$,累积影响为各样本影响的求和(一阶近似):
$$ \theta_{\text{new}} \approx \hat{\theta} + \sum_{z \in D_{\text{del}}} \mathcal{I}_{\text{param}}(z) $$
或者写成:
$$ \theta_{\text{new}} = \hat{\theta} - H_{\hat{\theta}}^{-1} \sum_{z \in D_{\text{del}}} \nabla_{\theta} L(z, \hat{\theta}) $$
4.2 完整流程
- 训练原始模型:获得参数 $\hat{\theta}$ 和 Hessian 的逆矩阵向量积的近似(通常用 LiSSA、共轭梯度等)。
- 收到删除请求:标定需要遗忘的数据点。
- 计算累积影响:对每个删除样本计算梯度,求和,再乘以 $-H^{-1}$。
- 参数修正:将修正量加到原始参数上,得到近似遗忘后的模型。
- 可选检验:通过成员推理攻击或损失差验证遗忘效果。
4.3 优点
- 速度极快:删除只需一次矩阵-向量乘法和梯度计算,远比重训高效。
- 无需原始数据:只需已删除样本的梯度,不影响其他数据。
- 数学可解释:有明确的统计学基础。
4.4 局限性与注意事项
- 一阶近似误差:当删除样本量较大,或损失曲面高度非线性时,近似可能不准。
- Hessian 逆的估算:对大模型,精确 $H^{-1}$ 不可行,需借助近似技术,引入额外误差。
- 凸性假设:影响函数严格建立在凸损失且解唯一的情况下,对非凸深度网络只在局部有效。
- 难以处理连锁效应:批量删除时,样本间的交互作用被忽略。
5. 梯度修正:更轻量的方案
当 Hessian 逆的计算成为瓶颈时,可以进一步简化:仅用梯度信息修正模型。
5.1 Newton 步的退化
影响函数本质上是在参数空间执行一步牛顿修正。如果我们忽略曲率,退化为梯度下降步:
$$ \theta_{\text{new}} = \hat{\theta} + \eta \sum_{z \in D_{\text{del}}} \nabla_{\theta} L(z, \hat{\theta}) $$
注意符号:此处是 沿负梯度方向移动? 实际上,我们要移除该样本的影响,本质上应是减小该样本对损失的贡献,但公式推导是上升方向。为避免混淆,更常见的实际做法是:
$$ \theta_{\text{new}} = \hat{\theta} - \epsilon \cdot \sum_{z \in D_{\text{del}}} \nabla_{\theta} L(z, \hat{\theta}) $$
其中 $\epsilon$ 为很小的步长。这相当于让模型“遗忘”这些数据——即增加这些样本上的损失,推动参数朝损失上升方向移动。具体符号方向需根据实现验证,但核心是反其道而行之。
5.2 与“负梯度”微调的关系
另一种梯度修正方法是对删除样本执行梯度上升,对其他保留样本继续梯度下降,混合进行几个 epoch,称为 遗忘微调(Unlearning Fine-tuning)。这实际上是一种启发式方法,脱离了影响函数的严格框架,但在实践中效果良好。
6. 简化代码示例(PyTorch 伪代码)
下面演示如何对一个简单逻辑回归模型进行影响函数近似遗忘。
import torch
# 假设已有训练好的模型 model,损失函数 loss_fn,全部训练数据 loader
# 原始参数保存在 theta_hat
theta_hat = model.parameters()
# 1. 计算 Hessian 逆向量积的近似(这里直接用伪逆,实际需迭代法)
def hvp(v):
# 返回 H^{-1} v 的近似
pass
# 2. 接收删除点列表 del_loader
delta_theta = 0.0
for x, y in del_loader:
loss = loss_fn(model(x), y)
grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
delta_theta += grad # 累积梯度向量
# 3. 用 Hessian 逆缩放
correction = hvp(delta_theta) # -H^{-1} * sum_grad
# 4. 修正模型参数
for param, corr in zip(model.parameters(), correction):
param.data += corr.data # 注意符号:影响函数是 -H^{-1}g,所以这里是加
对于梯度修正简化版,直接执行:
for param, grad_sum in zip(model.parameters(), delta_theta):
param.data -= 1e-4 * grad_sum
注意:在实践中,Hessian 逆向量积的计算常用共轭梯度法或 LiSSA 算法,避免显式构建 $n \times n$ 矩阵。
7. 实际应用策略
7.1 何时使用影响函数
- 小批量删除(如单条或几条数据)
- 凸模型或局部强凸(逻辑回归、SVM、浅层神经网络)
- 需要严格近似误差界限的场景
7.2 何时使用梯度修正或微调
- 深度神经网络,Hessian 高度奇异或难近似
- 批量删除请求
- 对速度要求极高,可接受一定残余记忆
7.3 增强遗忘保证
近似遗忘本身不能完全移除数据记忆。若要更强的隐私保证,可结合:
- 差分隐私训练:训练时注入噪声,限制单样本影响。
- 模型剪枝和蒸馏:遗忘后对模型进行蒸馏,抹去残余细节。
- 经验验证:用成员推理攻击检验删除样本是否仍可被识别。
8. 总结
近似遗忘利用 影响函数 或 梯度修正,在无需重新训练的情况下快速模拟数据删除。影响函数通过 Hessian 逆提供更精准的一阶修正,而梯度修正则是一种轻量化替代。
| 方法 | 核心公式 | 精度 | 计算代价 |
|---|---|---|---|
| 影响函数 | $-H^{-1}\nabla L(z)$ | 高(一阶) | 需 Hessian 逆向量积 |
| 梯度修正 | $-\epsilon \nabla L(z)$ | 中等 | 极低 |
| 遗忘微调 | 多步梯度上升+下降 | 视调参 | 若干 epoch |
选择哪种方案取决于可用的计算资源、模型的凸性以及对遗忘精度的法律要求。随着隐私法规的加强,这类技术将成为机器学习系统不可或缺的组件。