MLflow 实战:实验追踪、模型注册与部署
MLflow 模型注册与版本:从实验到生产
引言
在机器学习项目中,模型管理常常面临三个核心挑战:模型版本混乱、实验脉络不清 和部署切换困难。MLflow 作为一款开源平台,通过 Model Registry(模型注册表) 提供了一套中心化的模型生命周期管理方案。本教程将带你从零开始,使用 MLflow 完成实验追踪、模型注册、版本控制与阶段管理,最终实现安全的模型部署。
前置准备
开始前请确保环境已就绪:
pip install mlflow scikit-learn pandas numpy
- Tracking Server:用于记录实验和存储模型。你可直接使用本地文件存储,或启动远程服务:
mlflow server --host 0.0.0.0 --port 5000 \ --backend-store-uri sqlite:///mlflow.db \ --default-artifact-root ./mlruns - MLflow Tracking URI:若使用远程服务,需设置环境变量:
import mlflow mlflow.set_tracking_uri("http://localhost:5000")
第一步:训练并记录基础实验
在做任何注册动作前,我们需要先完成一次实验运行,生成待注册的模型。以下为示例训练代码:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import mlflow
import mlflow.sklearn
# 数据准备
data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
data.data, data.target, test_size=0.2, random_state=42)
# 启动 MLflow run
with mlflow.start_run(run_name="iris-rf-baseline") as run:
# 训练模型
n_estimators = 100
max_depth = 5
model = RandomForestClassifier(
n_estimators=n_estimators, max_depth=max_depth, random_state=42)
model.fit(X_train, y_train)
# 评估并记录指标
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
mlflow.log_param("n_estimators", n_estimators)
mlflow.log_param("max_depth", max_depth)
mlflow.log_metric("accuracy", acc)
# 记录模型(附带依赖环境和输入样例)
signature = mlflow.models.infer_signature(
X_train, model.predict(X_train))
mlflow.sklearn.log_model(
model, "model", signature=signature, input_example=X_train[:5])
运行上述代码后,在 Tracking UI 中即可看到名为 iris-rf-baseline 的实验记录,并生成一个 run_id。
关键点:
mlflow.sklearn.log_model会自动保存模型、conda 环境和签名,确保模型的可复现性。
第二步:注册模型到模型注册表
在 MLflow 中,模型注册表 与具体的模型名称(Registered Model)绑定,而非某次运行。你需要将一次运行中的模型“挂载”到一个已注册模型名下。
- 方式一:通过代码注册
from mlflow.tracking import MlflowClient
client = MlflowClient()
model_uri = f"runs:/{run.info.run_id}/model"
registered_model_name = "IrisClassifier"
# 如果该模型名尚不存在,会自动创建
result = client.create_model_version(
name=registered_model_name,
source=model_uri,
run_id=run.info.run_id
)
print(f"注册模型版本: {result.name} v{result.version}")
- 方式二:在 UI 中手动注册
打开 Tracking UI,进入对应 run 的Artifacts选项卡,找到model文件夹,点击右上角的 “Register Model” 按钮,选择名称或新建即可。
成功注册后,模型版本号会从 1 开始递增,且后续每次注册都会产生新版本。
第三步:管理模型版本
每一次通过 create_model_version 或 UI 注册都会创建一个不可变的版本,其状态包括:Source Run、Name、Version、Current Stage、Description 等。
- 查看所有版本:
for mv in client.search_model_versions(f"name='{registered_model_name}'"):
print(f"版本: {mv.version}, 阶段: {mv.current_stage}, 运行: {mv.run_id}")
- 添加版本描述(便于团队协作):
client.update_model_version(
name=registered_model_name,
version=1,
description="baseline random forest; accuracy 0.93"
)
- 根据模型名称获取最新版本或在特定阶段的版本:
# 获取最新版本(不考虑阶段)
latest_versions = client.get_latest_versions(registered_model_name)
for v in latest_versions:
print(v.version, v.current_stage)
# 获取处于生产阶段的最新版本(如无则返回 None)
prod_model = client.get_latest_versions(registered_model_name, stages=["Production"])
第四步:阶段转换——模型生命周期
模型注册表的强大之处在于阶段管理:你可以将某个版本标记为 Staging(预发布)、Production(生产)或 Archived(归档)。阶段转换是可逆的,且能触发自动化流程。
# 将版本1推送到预发布阶段
client.transition_model_version_stage(
name=registered_model_name,
version=1,
stage="Staging",
archive_existing_versions=False # 若设为 True,会将原 Staging 中的模型归档
)
# 确认无误后上线到生产
client.transition_model_version_stage(
name=registered_model_name,
version=1,
stage="Production",
archive_existing_versions=True # 将旧的生产版本归档
)
阶段转换后,可再次检查状态:
model_v = client.get_model_version(registered_model_name, 1)
print(f"v1 当前阶段: {model_v.current_stage}")
通过严格使用阶段切换,团队成员只需根据阶段名称取用模型,而不必硬编码版本号。
第五步:使用注册模型进行部署
模型注册表的最终目的,是让下游服务能稳定地加载生产模型。MLflow 提供了两种基于 URI 的加载方式:
- 通过阶段别名加载(推荐):
import mlflow
model_name = "IrisClassifier"
stage = "Production"
model = mlflow.sklearn.load_model(f"models:/{model_name}/{stage}")
prediction = model.predict(X_test)
当模型发生版本变更时,只需在注册表中转换阶段为 Production,代码无需改动。
- 通过版本号加载:
model = mlflow.sklearn.load_model(f"models:/{model_name}/1")
这种方式适合需要固定版本的场景,例如回溯分析。
- 部署为 REST API:
MLflow 为 models URI 提供了直接构建镜像和启动服务的命令:
# 导出为 Docker 镜像
mlflow models build-docker -m "models:/IrisClassifier/Production" -n "iris-classifier"
# 启动服务
mlflow models serve -m "models:/IrisClassifier/Production" -p 1234
你就可以通过 POST localhost:1234/invocations 发送 JSON 数据获取预测。
最佳实践
- 强制使用描述与标签:为每个模型版本添加说明,例如实验条件、数据集版本。
- 自动化阶段转换:通过 CI/CD 脚本,在测试通过后自动将模型提升至
Staging,并通过人工确认进入Production。 - 隔离实验与注册表:Tracking Server 与 Registry 的存储可以配置为独立的数据库,以提升可靠性。
- 模型签名必须提供:定义输入输出 schema,可有效防止部署时的数据格式错误。
- 定期归档:将不再使用的旧版本归档,保持注册表清爽。
总结
通过 MLflow 的模型注册表,你将杂乱无章的模型文件转化为 可追溯、可版本、可阶段切换 的标准化产物。结合实验追踪,团队能够在同一个平台内完成从 “idea → 实验 → 注册 → 上线” 的完整闭环。现在,你可以在自己的项目中实践这套流程,让模型迭代更有序、部署更安全。
下一站:尝试将 MLflow 与 Kubernetes、Spark 或云平台集成,构建企业级 MLOps 管道。