NVIDIA FLARE:面向医疗和工业的联邦学习 SDK

FreeGuideOnline 最新 2026-06-28

bash

创建并激活虚拟环境

python -m venv flare_env source flare_env/bin/activate # Linux/macOS

或 flare_env\Scripts\activate (Windows)

安装 FLARE 稳定版

pip install nvflare

若需医疗扩展功能,可同时安装 MONAI 集成

pip install nvflare[monai]


验证安装:

```bash
nvflare --version
# 应显示类似 nvflare 2.4.0 的信息

安装 POC 模式(快速体验)

FLARE 提供了一个概念验证(Proof of Concept)包,帮助你在单机上模拟联邦学习流程:

pip install nvflare[poc]

后文将详细介绍 POC 的使用方法。

快速开始:5 分钟跑通你的第一个联邦学习

1. 启动 POC 环境

打开终端,执行:

poc

此命令将自动生成一个名为 poc 的工作目录,并在其中创建管理员控制台和两个模拟客户端(site-1、site-2)。

你会看到类似如下的交互式控制台:

Welcome to NVFlare POC Console
Type '?' or 'help' for available commands.
>

2. 提交一个联邦训练作业

在 POC 控制台中,输入:

submit_job job_templates/sag_pt

这将提交一个基于 PyTorch 的联邦平均(FedAvg)作业,任务为在 CIFAR-10 子集上训练一个简单网络。

3. 观察训练进程

作业运行期间,你可以使用以下命令查看状态:

list_jobs
# 显示作业 ID 和状态(RUNNING、FINISHED 等)

view_job <job_id>
# 进入特定作业的详细视图,查看每个客户端的进度

训练完成后,全局模型会被保存在 poc/admin/transfer 目录下。

4. 停止 POC

shutdown

通过这四步,你已经体验了一次完整的联邦学习迭代。接下来我们深入理解 FLARE 的架构。

FLARE 核心概念

联邦学习角色

  • 服务器(Server):协调者,负责聚合客户端上传的模型更新,并分发新的全局模型。
  • 客户端(Client):拥有本地数据的参与方,执行本地训练并上传结果。
  • 管理员(Admin):通过管理控制台提交作业、监控任务状态。

作业(Job)与工作流(Workflow)

FLARE 中一次完整的联邦学习任务被称为作业,它由以下组件打包而成:

  • 工作流配置:定义服务器和客户端之间交互的流程(如 FedAvg 的聚合轮次)。
  • 任务(Task):服务器分配给客户端的计算单元,例如“用当前全局模型训练 1 个 epoch”。
  • 共享资源:如学习率调度器配置、初始模型文件等。

作业提交后,FLARE 的运行时引擎自动将任务分发给所有在线客户端,并处理聚合、容错等底层细节。

控制器(Controller)

控制器是工作流的大脑,运行在服务器端,负责编排整个协作过程。FLARE 提供了多种内置控制器:

  • ScatterAndGather:实现标准 FedAvg 流程。
  • ClientControlledFinalization:用于 FedProx 等客户端驱动的停止条件。
  • CyclicController:用于去中心化环状通信。
  • CrossSiteModelEval:跨站点模型评估专用控制器。

你可以通过组合现有控制器来构建自定义工作流,无需从零开始编写底层通信代码。

过滤器(Filter)

过滤器是一种拦截器机制,在数据流出/流入客户端之前对消息进行处理。常用过滤器包括:

  • PrivateKeyEncryptor / CipherStreamHandler:实现安全加密通信。
  • DifferentialPrivacy:给模型更新添加噪声保护隐私。
  • PercentilePrivacy:通过压缩技术隐藏个体贡献。
  • Quantization:压缩通信量,降低带宽消耗。

过滤器可以在作业配置中极其灵活地组合,无需修改训练代码。

构建你的第一个真实联邦学习作业

以下示例演示如何将本地 PyTorch 训练脚本改造为 FLARE 客户端代码,并编写对应的作业配置。

步骤一:编写客户端训练脚本

FLARE 要求客户端代码实现两个核心接口:

  • train(task_data):执行本地训练,并返回 Shareable 对象(包含模型更新和元数据)。
  • 如果需要进行本地验证,可定义 validate(task_data)

示例(train.py):

import torch
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.fl_trainer import FLTrainer

class MyTrainer(FLTrainer):
    def __init__(self, model, optimizer, criterion, epochs=5):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.epochs = epochs

    def train(self, shareable: Shareable) -> Shareable:
        # 从 shareable 中获取全局模型参数
        params = shareable.get("model_params")
        if params:
            self.model.load_state_dict(params)

        # 执行本地训练(此处简化为标准循环)
        self.model.train()
        for epoch in range(self.epochs):
            for data, target in local_dataloader:
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

        # 将更新后的模型参数打包返回
        result = Shareable()
        result["model_params"] = self.model.state_dict()
        result["meta"] = {"num_steps": len(local_dataloader) * self.epochs}
        result.set_return_code(ReturnCode.OK)
        return result

保存为 my_trainer.py

步骤二:准备作业配置

在作业目录中创建以下文件:

my_fedavg_job/
├── config_fed_server.json
├── config_fed_client.json
└── meta.json

config_fed_server.json 片段:

{
  "format_version": 2,
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "persistor",
      "path": "nvflare.app_common.model_persistor.FileModelPersistor",
      "args": { "model_dir": "models/server" }
    },
    {
      "id": "shareable_generator",
      "path": "nvflare.app_common.shareablegenerator.FullModelShareableGenerator",
      "args": {}
    },
    {
      "id": "aggregator",
      "path": "nvflare.app_common.aggregator.InTimeAccumulateWeightedAggregator",
      "args": { "expected_data_kind": "WEIGHTS" }
    }
  ],
  "workflows": [
    {
      "id": "scatter_gather_ctl",
      "name": "ScatterAndGather",
      "args": {
        "min_clients": 2,
        "num_rounds": 10,
        "start_round": 0,
        "wait_time_after_min_received": 0,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0
      }
    }
  ]
}

config_fed_client.json 片段:

{
  "format_version": 2,
  "executors": [
    {
      "tasks": ["train"],
      "executor": {
        "path": "my_trainer.MyTrainer",
        "args": {
          "model": { "path": "torchvision.models.resnet18" },
          "optimizer": { "path": "torch.optim.SGD", "args": {"lr": 0.01} },
          "criterion": { "path": "torch.nn.CrossEntropyLoss" },
          "epochs": 5
        }
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": [],
  "components": []
}

meta.json 示例:

{
  "name": "my_first_fedavg_job",
  "resource_spec": {},
  "min_clients": 2,
  "deploy_map": {
    "server": ["server"],
    "client": ["site-1", "site-2"]
  }
}

步骤三:提交并监控作业

在 POC 控制台中,切换到你的作业目录并提交:

submit_job /path/to/my_fedavg_job

使用 list_jobsview_job 跟踪状态。训练完成后,全局模型保存在服务器配置指定的 models/server 下。

高级特性一览

自定义工作流与控制器继承

如果内置工作流无法满足需求,你可以创建自定义控制器。只需继承 nvflare.apis.Controller 并实现 control_flow() 方法即可。示例(打乱客户端顺序的 FedAvg):

from nvflare.apis.controller import Controller, ClientTask, Task
import random

class ShuffleScatterAndGather(Controller):
    def control_flow(self, abort_signal, fl_ctx):
        for round in range(self._num_rounds):
            clients = self.get_clients()
            random.shuffle(clients)
            task = Task(name="train", data=global_model_shareable)
            self.broadcast_and_wait(task, min_responses=2, fl_ctx=fl_ctx, targets=clients)
            # 聚合逻辑...