分割学习 Split Learning:将模型切分为客户端与服务器

FreeGuideOnline 最新 2026-06-28

什么是分割学习

分割学习是一种分布式机器学习范式,将神经网络模型按层切分为两个部分:一部分运行在客户端(如移动设备、边缘节点),另一部分运行在服务器。训练过程中,客户端仅将中间激活值(而非原始数据)发送给服务器,服务器完成剩余计算后回传梯度,双方协作完成模型更新。这种设计在保护数据隐私的同时,降低了客户端的计算和通信负担。

核心工作流程

模型切分与角色分配

模型架构被选定一个切分点,例如神经网络第k层之后。客户端保留前k层(称为客户端模型),服务器持有其余层(服务器模型)。切分点位置可根据客户端算力、网络带宽和隐私需求灵活调整。

前向传播

  1. 客户端用本地数据执行前向传播,直到切分点,得到中间激活张量(通常称为破碎数据Smashed Data)。
  2. 客户端仅将中间激活发送给服务器,不共享原始输入或标签
  3. 服务器接收激活后,继续前向传播至输出层,计算损失。

反向传播与梯度同步

  1. 服务器根据损失计算回传的梯度,从输出层传播至切分点。
  2. 服务器将切分点处的梯度回传给客户端。
  3. 客户端从切分点继续反向传播,更新本地模型参数。
  4. 服务器也更新自己的参数。 整个过程客户端与服务器各自独立更新,无需交换模型权重或原始数据。

与传统联邦学习的区别

对比维度 联邦学习 分割学习
模型分布 各客户端持有完整模型副本 模型被切分为两部分,客户端仅持有部分
上传内容 模型参数或梯度 中间层的激活值和梯度
客户端算力要求 需要计算完整前向+反向传播 仅计算部分层,负担更轻
隐私泄漏风险 可能从梯度反推原始数据 中间激活信息量更少,更易添加噪声
通信模式 并列多轮聚合 顺序串行训练(也可并行化)

分割学习尤其适合客户端资源极度受限数据高度敏感的场景,例如IoT传感器、医疗可穿戴设备等。

分割学习的三种主要模式

标准分割学习

单个客户端与服务器顺序训练一条数据。客户端完成前向部分,服务器完成后向部分并返回梯度。训练是串行的,效率较低,但实现简单。

并行分割学习

多个客户端共享同一个服务器模型部分。客户端并行执行各自的前向计算,将中间激活发送给服务器;服务器可批量处理多个激活,提升训练吞吐量。这需要协调客户端之间的同步。

分割联邦学习

将分割学习与联邦学习结合:客户端持有部分模型,并可将本地部分模型在多轮训练后进行联邦聚合(例如FedAvg)。这样既能利用分割降低客户端算力,又能通过联邦聚合加速收敛。服务器模型部分则始终集中更新。

动手实现一个微型分割学习示例

以下示例使用PyTorch构建一个简单的MLP模型,演示分割训练的过程。我们模拟一个客户端和一个服务器。

环境准备与模型定义

import torch
import torch.nn as nn
import torch.optim as optim

# 假设总模型为 输入784 → 隐藏128 → 输出10
class ClientModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(784, 128)
    def forward(self, x):
        return torch.relu(self.net(x))  # 客户端部分输出

class ServerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(128, 10)
    def forward(self, x):
        return self.net(x)  # 无激活,直接logits

模拟一次分割训练迭代

# 初始化
client = ClientModel()
server = ServerModel()
criterion = nn.CrossEntropyLoss()
client_optim = optim.SGD(client.parameters(), lr=0.01)
server_optim = optim.SGD(server.parameters(), lr=0.01)

# 模拟单条数据
x_client = torch.randn(1, 784)        # 客户端本地数据
y_true = torch.randint(0, 10, (1,))   # 真实标签

# 1. 客户端前向
with torch.no_grad():
    activation = client(x_client)
# 发送activation至服务器(此处本地模拟)

# 2. 服务器前向与梯度回传
server_optim.zero_grad()
activation_server = activation.detach().requires_grad_(True)  # 保留梯度连接
output = server(activation_server)
loss = criterion(output, y_true)
loss.backward()

# 3. 服务器回传梯度至客户端(activation_server.grad)
grad_to_client = activation_server.grad
server_optim.step()

# 4. 客户端根据接收的梯度进行反向传播和更新
client_optim.zero_grad()
activation.backward(grad_to_client)   # 将梯度传回客户端前向的计算图
client_optim.step()

print(f"Loss: {loss.item():.4f}")

实际部署时,步骤间的通信需通过网络传输中间激活和梯度。

隐私增强技术整合

由于中间激活仍可能泄露数据特征,分割学习常与以下技术结合:

  • 差分隐私:在发送激活或梯度前注入校准噪声,实现严格的隐私保障。
  • 安全多方计算/同态加密:对中间激活进行加密传输,服务器在密文下计算,但会显著增加计算开销。
  • 激活扰动:直接对激活加扰动或量化,简单有效但可能影响模型精度。

实际部署的通信考量

  • 激活压缩:可使用量化、稀疏化或主成分分析降低激活传输量。
  • 切分点动态调整:根据当前网络延迟和客户端电量,自动切换切分层,实现自适应分割学习。
  • 无标签场景:服务器持有标签是一种常见假设;若无标签,需使用自监督或联邦迁移学习等方法。

主要优势与局限

优势

  • 客户端无需完整模型,计算负担极低。
  • 原始数据不出本地,隐私保护天然较强。
  • 模型架构与大小对客户端透明,便于服务器管理。

局限

  • 训练速度受客户端顺序处理限制,并行优化复杂。
  • 中间激活的隐私风险仍需要额外措施。
  • 切分点选择依赖经验,不当切分可能导致收敛缓慢。

典型应用场景

  • 医疗影像分析:医院边缘设备仅处理前几层卷积,敏感影像不离开本地。
  • 智能手机输入法预测:手机端计算嵌入层,云端完成语言模型推理与更新。
  • 工业物联网:传感器节点运行轻量特征提取,中央服务器执行复杂故障检测。

分割学习为隐私敏感、资源受限环境提供了一种实用的分布式训练选择。理解其切分原理与通信模式后,你可针对任务特性设计更高效的混合方案。