NVIDIA FLARE:面向医疗和工业的联邦学习 SDK
bash
创建并激活虚拟环境
python -m venv flare_env source flare_env/bin/activate # Linux/macOS
或 flare_env\Scripts\activate (Windows)
安装 FLARE 稳定版
pip install nvflare
若需医疗扩展功能,可同时安装 MONAI 集成
pip install nvflare[monai]
验证安装:
```bash
nvflare --version
# 应显示类似 nvflare 2.4.0 的信息
安装 POC 模式(快速体验)
FLARE 提供了一个概念验证(Proof of Concept)包,帮助你在单机上模拟联邦学习流程:
pip install nvflare[poc]
后文将详细介绍 POC 的使用方法。
快速开始:5 分钟跑通你的第一个联邦学习
1. 启动 POC 环境
打开终端,执行:
poc
此命令将自动生成一个名为 poc 的工作目录,并在其中创建管理员控制台和两个模拟客户端(site-1、site-2)。
你会看到类似如下的交互式控制台:
Welcome to NVFlare POC Console
Type '?' or 'help' for available commands.
>
2. 提交一个联邦训练作业
在 POC 控制台中,输入:
submit_job job_templates/sag_pt
这将提交一个基于 PyTorch 的联邦平均(FedAvg)作业,任务为在 CIFAR-10 子集上训练一个简单网络。
3. 观察训练进程
作业运行期间,你可以使用以下命令查看状态:
list_jobs
# 显示作业 ID 和状态(RUNNING、FINISHED 等)
view_job <job_id>
# 进入特定作业的详细视图,查看每个客户端的进度
训练完成后,全局模型会被保存在 poc/admin/transfer 目录下。
4. 停止 POC
shutdown
通过这四步,你已经体验了一次完整的联邦学习迭代。接下来我们深入理解 FLARE 的架构。
FLARE 核心概念
联邦学习角色
- 服务器(Server):协调者,负责聚合客户端上传的模型更新,并分发新的全局模型。
- 客户端(Client):拥有本地数据的参与方,执行本地训练并上传结果。
- 管理员(Admin):通过管理控制台提交作业、监控任务状态。
作业(Job)与工作流(Workflow)
FLARE 中一次完整的联邦学习任务被称为作业,它由以下组件打包而成:
- 工作流配置:定义服务器和客户端之间交互的流程(如 FedAvg 的聚合轮次)。
- 任务(Task):服务器分配给客户端的计算单元,例如“用当前全局模型训练 1 个 epoch”。
- 共享资源:如学习率调度器配置、初始模型文件等。
作业提交后,FLARE 的运行时引擎自动将任务分发给所有在线客户端,并处理聚合、容错等底层细节。
控制器(Controller)
控制器是工作流的大脑,运行在服务器端,负责编排整个协作过程。FLARE 提供了多种内置控制器:
ScatterAndGather:实现标准 FedAvg 流程。ClientControlledFinalization:用于 FedProx 等客户端驱动的停止条件。CyclicController:用于去中心化环状通信。CrossSiteModelEval:跨站点模型评估专用控制器。
你可以通过组合现有控制器来构建自定义工作流,无需从零开始编写底层通信代码。
过滤器(Filter)
过滤器是一种拦截器机制,在数据流出/流入客户端之前对消息进行处理。常用过滤器包括:
PrivateKeyEncryptor/CipherStreamHandler:实现安全加密通信。DifferentialPrivacy:给模型更新添加噪声保护隐私。PercentilePrivacy:通过压缩技术隐藏个体贡献。Quantization:压缩通信量,降低带宽消耗。
过滤器可以在作业配置中极其灵活地组合,无需修改训练代码。
构建你的第一个真实联邦学习作业
以下示例演示如何将本地 PyTorch 训练脚本改造为 FLARE 客户端代码,并编写对应的作业配置。
步骤一:编写客户端训练脚本
FLARE 要求客户端代码实现两个核心接口:
train(task_data):执行本地训练,并返回Shareable对象(包含模型更新和元数据)。- 如果需要进行本地验证,可定义
validate(task_data)。
示例(train.py):
import torch
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.fl_trainer import FLTrainer
class MyTrainer(FLTrainer):
def __init__(self, model, optimizer, criterion, epochs=5):
super().__init__()
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.epochs = epochs
def train(self, shareable: Shareable) -> Shareable:
# 从 shareable 中获取全局模型参数
params = shareable.get("model_params")
if params:
self.model.load_state_dict(params)
# 执行本地训练(此处简化为标准循环)
self.model.train()
for epoch in range(self.epochs):
for data, target in local_dataloader:
self.optimizer.zero_grad()
output = self.model(data)
loss = self.criterion(output, target)
loss.backward()
self.optimizer.step()
# 将更新后的模型参数打包返回
result = Shareable()
result["model_params"] = self.model.state_dict()
result["meta"] = {"num_steps": len(local_dataloader) * self.epochs}
result.set_return_code(ReturnCode.OK)
return result
保存为 my_trainer.py。
步骤二:准备作业配置
在作业目录中创建以下文件:
my_fedavg_job/
├── config_fed_server.json
├── config_fed_client.json
└── meta.json
config_fed_server.json 片段:
{
"format_version": 2,
"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"components": [
{
"id": "persistor",
"path": "nvflare.app_common.model_persistor.FileModelPersistor",
"args": { "model_dir": "models/server" }
},
{
"id": "shareable_generator",
"path": "nvflare.app_common.shareablegenerator.FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"path": "nvflare.app_common.aggregator.InTimeAccumulateWeightedAggregator",
"args": { "expected_data_kind": "WEIGHTS" }
}
],
"workflows": [
{
"id": "scatter_gather_ctl",
"name": "ScatterAndGather",
"args": {
"min_clients": 2,
"num_rounds": 10,
"start_round": 0,
"wait_time_after_min_received": 0,
"aggregator_id": "aggregator",
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0
}
}
]
}
config_fed_client.json 片段:
{
"format_version": 2,
"executors": [
{
"tasks": ["train"],
"executor": {
"path": "my_trainer.MyTrainer",
"args": {
"model": { "path": "torchvision.models.resnet18" },
"optimizer": { "path": "torch.optim.SGD", "args": {"lr": 0.01} },
"criterion": { "path": "torch.nn.CrossEntropyLoss" },
"epochs": 5
}
}
}
],
"task_result_filters": [],
"task_data_filters": [],
"components": []
}
meta.json 示例:
{
"name": "my_first_fedavg_job",
"resource_spec": {},
"min_clients": 2,
"deploy_map": {
"server": ["server"],
"client": ["site-1", "site-2"]
}
}
步骤三:提交并监控作业
在 POC 控制台中,切换到你的作业目录并提交:
submit_job /path/to/my_fedavg_job
使用 list_jobs 和 view_job 跟踪状态。训练完成后,全局模型保存在服务器配置指定的 models/server 下。
高级特性一览
自定义工作流与控制器继承
如果内置工作流无法满足需求,你可以创建自定义控制器。只需继承 nvflare.apis.Controller 并实现 control_flow() 方法即可。示例(打乱客户端顺序的 FedAvg):
from nvflare.apis.controller import Controller, ClientTask, Task
import random
class ShuffleScatterAndGather(Controller):
def control_flow(self, abort_signal, fl_ctx):
for round in range(self._num_rounds):
clients = self.get_clients()
random.shuffle(clients)
task = Task(name="train", data=global_model_shareable)
self.broadcast_and_wait(task, min_responses=2, fl_ctx=fl_ctx, targets=clients)
# 聚合逻辑...