分割学习 Split Learning:将模型切分为客户端与服务器
FreeGuideOnline
最新
2026-06-28
什么是分割学习
分割学习是一种分布式机器学习范式,将神经网络模型按层切分为两个部分:一部分运行在客户端(如移动设备、边缘节点),另一部分运行在服务器。训练过程中,客户端仅将中间激活值(而非原始数据)发送给服务器,服务器完成剩余计算后回传梯度,双方协作完成模型更新。这种设计在保护数据隐私的同时,降低了客户端的计算和通信负担。
核心工作流程
模型切分与角色分配
模型架构被选定一个切分点,例如神经网络第k层之后。客户端保留前k层(称为客户端模型),服务器持有其余层(服务器模型)。切分点位置可根据客户端算力、网络带宽和隐私需求灵活调整。
前向传播
- 客户端用本地数据执行前向传播,直到切分点,得到中间激活张量(通常称为破碎数据Smashed Data)。
- 客户端仅将中间激活发送给服务器,不共享原始输入或标签。
- 服务器接收激活后,继续前向传播至输出层,计算损失。
反向传播与梯度同步
- 服务器根据损失计算回传的梯度,从输出层传播至切分点。
- 服务器将切分点处的梯度回传给客户端。
- 客户端从切分点继续反向传播,更新本地模型参数。
- 服务器也更新自己的参数。 整个过程客户端与服务器各自独立更新,无需交换模型权重或原始数据。
与传统联邦学习的区别
| 对比维度 | 联邦学习 | 分割学习 |
|---|---|---|
| 模型分布 | 各客户端持有完整模型副本 | 模型被切分为两部分,客户端仅持有部分 |
| 上传内容 | 模型参数或梯度 | 中间层的激活值和梯度 |
| 客户端算力要求 | 需要计算完整前向+反向传播 | 仅计算部分层,负担更轻 |
| 隐私泄漏风险 | 可能从梯度反推原始数据 | 中间激活信息量更少,更易添加噪声 |
| 通信模式 | 并列多轮聚合 | 顺序串行训练(也可并行化) |
分割学习尤其适合客户端资源极度受限且数据高度敏感的场景,例如IoT传感器、医疗可穿戴设备等。
分割学习的三种主要模式
标准分割学习
单个客户端与服务器顺序训练一条数据。客户端完成前向部分,服务器完成后向部分并返回梯度。训练是串行的,效率较低,但实现简单。
并行分割学习
多个客户端共享同一个服务器模型部分。客户端并行执行各自的前向计算,将中间激活发送给服务器;服务器可批量处理多个激活,提升训练吞吐量。这需要协调客户端之间的同步。
分割联邦学习
将分割学习与联邦学习结合:客户端持有部分模型,并可将本地部分模型在多轮训练后进行联邦聚合(例如FedAvg)。这样既能利用分割降低客户端算力,又能通过联邦聚合加速收敛。服务器模型部分则始终集中更新。
动手实现一个微型分割学习示例
以下示例使用PyTorch构建一个简单的MLP模型,演示分割训练的过程。我们模拟一个客户端和一个服务器。
环境准备与模型定义
import torch
import torch.nn as nn
import torch.optim as optim
# 假设总模型为 输入784 → 隐藏128 → 输出10
class ClientModel(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Linear(784, 128)
def forward(self, x):
return torch.relu(self.net(x)) # 客户端部分输出
class ServerModel(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Linear(128, 10)
def forward(self, x):
return self.net(x) # 无激活,直接logits
模拟一次分割训练迭代
# 初始化
client = ClientModel()
server = ServerModel()
criterion = nn.CrossEntropyLoss()
client_optim = optim.SGD(client.parameters(), lr=0.01)
server_optim = optim.SGD(server.parameters(), lr=0.01)
# 模拟单条数据
x_client = torch.randn(1, 784) # 客户端本地数据
y_true = torch.randint(0, 10, (1,)) # 真实标签
# 1. 客户端前向
with torch.no_grad():
activation = client(x_client)
# 发送activation至服务器(此处本地模拟)
# 2. 服务器前向与梯度回传
server_optim.zero_grad()
activation_server = activation.detach().requires_grad_(True) # 保留梯度连接
output = server(activation_server)
loss = criterion(output, y_true)
loss.backward()
# 3. 服务器回传梯度至客户端(activation_server.grad)
grad_to_client = activation_server.grad
server_optim.step()
# 4. 客户端根据接收的梯度进行反向传播和更新
client_optim.zero_grad()
activation.backward(grad_to_client) # 将梯度传回客户端前向的计算图
client_optim.step()
print(f"Loss: {loss.item():.4f}")
实际部署时,步骤间的通信需通过网络传输中间激活和梯度。
隐私增强技术整合
由于中间激活仍可能泄露数据特征,分割学习常与以下技术结合:
- 差分隐私:在发送激活或梯度前注入校准噪声,实现严格的隐私保障。
- 安全多方计算/同态加密:对中间激活进行加密传输,服务器在密文下计算,但会显著增加计算开销。
- 激活扰动:直接对激活加扰动或量化,简单有效但可能影响模型精度。
实际部署的通信考量
- 激活压缩:可使用量化、稀疏化或主成分分析降低激活传输量。
- 切分点动态调整:根据当前网络延迟和客户端电量,自动切换切分层,实现自适应分割学习。
- 无标签场景:服务器持有标签是一种常见假设;若无标签,需使用自监督或联邦迁移学习等方法。
主要优势与局限
优势:
- 客户端无需完整模型,计算负担极低。
- 原始数据不出本地,隐私保护天然较强。
- 模型架构与大小对客户端透明,便于服务器管理。
局限:
- 训练速度受客户端顺序处理限制,并行优化复杂。
- 中间激活的隐私风险仍需要额外措施。
- 切分点选择依赖经验,不当切分可能导致收敛缓慢。
典型应用场景
- 医疗影像分析:医院边缘设备仅处理前几层卷积,敏感影像不离开本地。
- 智能手机输入法预测:手机端计算嵌入层,云端完成语言模型推理与更新。
- 工业物联网:传感器节点运行轻量特征提取,中央服务器执行复杂故障检测。
分割学习为隐私敏感、资源受限环境提供了一种实用的分布式训练选择。理解其切分原理与通信模式后,你可针对任务特性设计更高效的混合方案。