MLflow 实战:实验追踪、模型注册与部署

FreeGuideOnline 最新 2026-06-20

MLflow 模型注册与版本:从实验到生产

引言

在机器学习项目中,模型管理常常面临三个核心挑战:模型版本混乱实验脉络不清部署切换困难。MLflow 作为一款开源平台,通过 Model Registry(模型注册表) 提供了一套中心化的模型生命周期管理方案。本教程将带你从零开始,使用 MLflow 完成实验追踪、模型注册、版本控制与阶段管理,最终实现安全的模型部署。

前置准备

开始前请确保环境已就绪:

pip install mlflow scikit-learn pandas numpy
  • Tracking Server:用于记录实验和存储模型。你可直接使用本地文件存储,或启动远程服务:
    mlflow server --host 0.0.0.0 --port 5000 \
        --backend-store-uri sqlite:///mlflow.db \
        --default-artifact-root ./mlruns
    
  • MLflow Tracking URI:若使用远程服务,需设置环境变量:
    import mlflow
    mlflow.set_tracking_uri("http://localhost:5000")
    

第一步:训练并记录基础实验

在做任何注册动作前,我们需要先完成一次实验运行,生成待注册的模型。以下为示例训练代码:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import mlflow
import mlflow.sklearn

# 数据准备
data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42)

# 启动 MLflow run
with mlflow.start_run(run_name="iris-rf-baseline") as run:
    # 训练模型
    n_estimators = 100
    max_depth = 5
    model = RandomForestClassifier(
        n_estimators=n_estimators, max_depth=max_depth, random_state=42)
    model.fit(X_train, y_train)

    # 评估并记录指标
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)

    mlflow.log_param("n_estimators", n_estimators)
    mlflow.log_param("max_depth", max_depth)
    mlflow.log_metric("accuracy", acc)

    # 记录模型(附带依赖环境和输入样例)
    signature = mlflow.models.infer_signature(
        X_train, model.predict(X_train))
    mlflow.sklearn.log_model(
        model, "model", signature=signature, input_example=X_train[:5])

运行上述代码后,在 Tracking UI 中即可看到名为 iris-rf-baseline 的实验记录,并生成一个 run_id

关键点mlflow.sklearn.log_model 会自动保存模型、conda 环境和签名,确保模型的可复现性。

第二步:注册模型到模型注册表

在 MLflow 中,模型注册表 与具体的模型名称(Registered Model)绑定,而非某次运行。你需要将一次运行中的模型“挂载”到一个已注册模型名下。

  • 方式一:通过代码注册
from mlflow.tracking import MlflowClient

client = MlflowClient()
model_uri = f"runs:/{run.info.run_id}/model"
registered_model_name = "IrisClassifier"

# 如果该模型名尚不存在,会自动创建
result = client.create_model_version(
    name=registered_model_name,
    source=model_uri,
    run_id=run.info.run_id
)
print(f"注册模型版本: {result.name} v{result.version}")
  • 方式二:在 UI 中手动注册
    打开 Tracking UI,进入对应 run 的 Artifacts 选项卡,找到 model 文件夹,点击右上角的 “Register Model” 按钮,选择名称或新建即可。

成功注册后,模型版本号会从 1 开始递增,且后续每次注册都会产生新版本。

第三步:管理模型版本

每一次通过 create_model_version 或 UI 注册都会创建一个不可变的版本,其状态包括:Source RunNameVersionCurrent StageDescription 等。

  • 查看所有版本
for mv in client.search_model_versions(f"name='{registered_model_name}'"):
    print(f"版本: {mv.version}, 阶段: {mv.current_stage}, 运行: {mv.run_id}")
  • 添加版本描述(便于团队协作):
client.update_model_version(
    name=registered_model_name,
    version=1,
    description="baseline random forest; accuracy 0.93"
)
  • 根据模型名称获取最新版本或在特定阶段的版本
# 获取最新版本(不考虑阶段)
latest_versions = client.get_latest_versions(registered_model_name)
for v in latest_versions:
    print(v.version, v.current_stage)

# 获取处于生产阶段的最新版本(如无则返回 None)
prod_model = client.get_latest_versions(registered_model_name, stages=["Production"])

第四步:阶段转换——模型生命周期

模型注册表的强大之处在于阶段管理:你可以将某个版本标记为 Staging(预发布)、Production(生产)或 Archived(归档)。阶段转换是可逆的,且能触发自动化流程。

# 将版本1推送到预发布阶段
client.transition_model_version_stage(
    name=registered_model_name,
    version=1,
    stage="Staging",
    archive_existing_versions=False  # 若设为 True,会将原 Staging 中的模型归档
)

# 确认无误后上线到生产
client.transition_model_version_stage(
    name=registered_model_name,
    version=1,
    stage="Production",
    archive_existing_versions=True   # 将旧的生产版本归档
)

阶段转换后,可再次检查状态:

model_v = client.get_model_version(registered_model_name, 1)
print(f"v1 当前阶段: {model_v.current_stage}")

通过严格使用阶段切换,团队成员只需根据阶段名称取用模型,而不必硬编码版本号。

第五步:使用注册模型进行部署

模型注册表的最终目的,是让下游服务能稳定地加载生产模型。MLflow 提供了两种基于 URI 的加载方式:

  • 通过阶段别名加载(推荐):
import mlflow

model_name = "IrisClassifier"
stage = "Production"
model = mlflow.sklearn.load_model(f"models:/{model_name}/{stage}")
prediction = model.predict(X_test)

当模型发生版本变更时,只需在注册表中转换阶段为 Production,代码无需改动。

  • 通过版本号加载:
model = mlflow.sklearn.load_model(f"models:/{model_name}/1")

这种方式适合需要固定版本的场景,例如回溯分析。

  • 部署为 REST API:

MLflow 为 models URI 提供了直接构建镜像和启动服务的命令:

# 导出为 Docker 镜像
mlflow models build-docker -m "models:/IrisClassifier/Production" -n "iris-classifier"
# 启动服务
mlflow models serve -m "models:/IrisClassifier/Production" -p 1234

你就可以通过 POST localhost:1234/invocations 发送 JSON 数据获取预测。

最佳实践

  1. 强制使用描述与标签:为每个模型版本添加说明,例如实验条件、数据集版本。
  2. 自动化阶段转换:通过 CI/CD 脚本,在测试通过后自动将模型提升至 Staging,并通过人工确认进入 Production
  3. 隔离实验与注册表:Tracking Server 与 Registry 的存储可以配置为独立的数据库,以提升可靠性。
  4. 模型签名必须提供:定义输入输出 schema,可有效防止部署时的数据格式错误。
  5. 定期归档:将不再使用的旧版本归档,保持注册表清爽。

总结

通过 MLflow 的模型注册表,你将杂乱无章的模型文件转化为 可追溯、可版本、可阶段切换 的标准化产物。结合实验追踪,团队能够在同一个平台内完成从 “idea → 实验 → 注册 → 上线” 的完整闭环。现在,你可以在自己的项目中实践这套流程,让模型迭代更有序、部署更安全。

下一站:尝试将 MLflow 与 Kubernetes、Spark 或云平台集成,构建企业级 MLOps 管道。