MLflow 模型管理:实验跟踪与模型注册
MLflow 模型管理:实验跟踪与模型注册
实验的可复现性与模型的生命周期管理是机器学习工程中极易被忽视却至关重要的环节。MLflow 提供了轻量级的 API 与中心化的服务,让你可以在任意环境中跟踪实验参数、代码版本、指标与产物(模型文件),并通过模型注册中心实现从“实验阶段”到“生产部署”的平滑过渡。本教程将带你一步步掌握这两个核心能力。
实验跟踪:记录每一次尝试
实验跟踪让你能够将每次运行的参数、指标和输出文件(如模型、图表)集中存储,并通过 UI 或 API 进行对比和检索。这是构建可复现实验的基础。
安装与基础设置
pip install mlflow
如果希望使用远程跟踪服务器(如团队共享的 Tracking Server),需要在启动 MLflow 前配置跟踪 URI:
export MLFLOW_TRACKING_URI=http://<your-tracking-server>:5000
对于本地学习,直接使用默认的本地存储即可。
启动 Tracking UI
在终端中执行以下命令,即可在本地启动一个轻量 Web 界面,用于可视化所有记录的数据:
mlflow ui
默认访问地址为 http://127.0.0.1:5000。你将在该页面看到所有实验、运行记录、指标曲线和产物列表。
使用 Python API 记录实验
创建或设置实验
import mlflow
# 创建一个新实验,或获取已有的实验 ID
experiment_name = "demo_experiment"
mlflow.set_experiment(experiment_name)
所有后续的 mlflow.start_run() 都会被关联到这个实验下。
记录参数、指标和模型
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 准备数据
data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
data.data, data.target, test_size=0.2, random_state=42
)
# 开始一个 MLflow run,并记录参数、指标与模型
with mlflow.start_run(run_name="random_forest_v1"):
# 记录超参数
params = {"n_estimators": 100, "max_depth": 5, "random_state": 42}
mlflow.log_params(params)
# 训练模型
model = RandomForestClassifier(**params)
model.fit(X_train, y_train)
# 计算并记录指标
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
mlflow.log_metric("accuracy", acc)
# 记录模型文件(自动保存为 MLflow 模型格式)
mlflow.sklearn.log_model(model, "random_forest_model")
记录额外文件与参数
你可以记录任意类型的文件(如图片、文本、序列化对象)或批量记录参数/指标:
mlflow.log_artifact("confusion_matrix.png")
mlflow.log_metrics({"precision": 0.88, "recall": 0.92})
mlflow.log_dict({"feature_importance": list(model.feature_importances_)}, "feature_importance.json")
自动记录(Auto Logging)
对于常见框架,MLflow 提供了自动记录功能,无需手动插入记录代码。
import mlflow.sklearn
mlflow.sklearn.autolog()
# 开启自动记录后,任何 sklearn 模型训练都会被自动捕捉
with mlflow.start_run():
model = RandomForestClassifier(n_estimators=200)
model.fit(X_train, y_train)
自动记录会捕获模型的超参数、训练过程中的指标(若框架支持,如 XGBoost 的评估集合)以及模型文件。可通过 mlflow.<flavor>.autolog(disable=True) 关闭。
模型注册中心:从实验到生产
模型注册中心(Model Registry)提供集中化的模型生命周期管理,包括版本控制、阶段标注(Staging、Production、Archived)和权限控制(若与企业版集成)。
向注册中心注册模型
1. 使用 MLflow Client 注册已有模型
如果你已经通过 Tracking 记录了大量模型,可以通过 UI 或 API 将某个 run 的模型注册到 Registry。
from mlflow.tracking import MlflowClient
client = MlflowClient()
# 找到想要注册的 run 中记录的模型 URI
# 格式为 "runs:/<run_id>/<artifact_path>"
model_uri = "runs:/a1b2c3d4/random_forest_model"
# 以名称 "IrisRF" 注册模型,不存在会自动创建,并分配一个版本号
result = client.create_model_version(
name="IrisRF",
source=model_uri,
run_id="a1b2c3d4"
)
print(f"注册成功,版本号:{result.version}")
2. 在创建 Run 时直接注册
更简便的方式:在 mlflow.sklearn.log_model 时通过 registered_model_name 参数直接注册。
with mlflow.start_run(run_name="rf_prod_candidate"):
model = RandomForestClassifier(n_estimators=150, max_depth=10)
model.fit(X_train, y_train)
acc = accuracy_score(y_test, model.predict(X_test))
mlflow.log_metric("accuracy", acc)
# 同时记录模型并直接注册到 Registry
mlflow.sklearn.log_model(
model,
"model",
registered_model_name="IrisRF"
)
管理模型版本与阶段
模型注册后,每个版本都可分配一个阶段(Stage)来表明其就绪状态。常用阶段包括:
- None:初始状态,未分配阶段。
- Staging:候选版本,可进行线上验证或 A/B 测试。
- Production:正式生产模型。
- Archived:已归档的旧版本。
client = MlflowClient()
# 将某个版本过渡到 Staging
client.transition_model_version_stage(
name="IrisRF",
version=1,
stage="Staging"
)
# 确认验证后,提升为 Production(注意:同一时间只能有一个 Production 版本)
client.transition_model_version_stage(
name="IrisRF",
version=1,
stage="Production"
)
# 添加描述和标签
client.update_model_version(
name="IrisRF",
version=1,
description="基线随机森林模型,准确率 0.93"
)
消费模型:加载已注册的模型推理
在部署或推理时,无需关心底层存储路径,直接使用 models:/ URI 根据名称和阶段/版本获取模型。
import mlflow.pyfunc
# 加载 Production 阶段的最新版本
model = mlflow.pyfunc.load_model("models:/IrisRF/Production")
# 或者加载指定版本
model_v2 = mlflow.pyfunc.load_model("models:/IrisRF/2")
# 进行推理
prediction = model.predict(X_test[:5])
pyfunc 是一个通用包装器,可以加载任何 MLflow 模型风格的模型,使得消费端与训练框架解耦。
最佳实践与工作流
组织实验与运行命名
- 将不同项目或任务放在独立实验中,例如
recommender_ctr_v2。 - 每次运行使用有意义的
run_name,例如xgboost_lr0.01_depth6,便于在 UI 中快速识别。
记录运行上下文信息
通过 mlflow.set_tags() 或 start_run(tags=...) 记录环境、数据版本或 Git commit hash。
mlflow.set_tag("data_version", "v1.2")
mlflow.set_tag("git_commit", "a3f2b1c")
使用 MLflow Projects 打包可复现运行
将代码与环境依赖打包成 MLflow Project,并与 Tracking 结合,可以一键在任意环境中重现历史实验。
安全与权限(企业场景)
如果你使用的是 MLflow 企业版或 Databricks 托管 MLflow,可以:
- 为模型注册中心设置访问控制列表(ACL)。
- 限制谁能将模型过渡到 Production。
- 审核阶段变更记录。
自动化 CI/CD 集成
在持续集成流水线中,可以:
- 训练后自动将模型注册并标记为
Staging。 - 触发自动化测试(如模型验证、性能测试)。
- 通过 API 将测试通过的版本推至
Production。 - 使用 Webhook 通知下游部署系统。
常见问题排查
Q: 找不到 Tracking URI 导致连接失败
如果忘记设置 MLFLOW_TRACKING_URI,运行 mlflow ui 可能看到空界面。检查是否指向了正确的 Tracking Server。
Q: 模型注册失败,提示“Model already exists”
若使用 client.create_model_version 时名称未提前创建,需要先通过 client.create_registered_model("IrisRF") 创建模型占位,或使用 mlflow.register_model() 自动创建(注意:后者已弃用,建议使用 create_model_version)。
Q: 加载 Production 模型失败
确认至少有一个模型版本处于 Production 阶段。如果没有,URI models:/IrisRF/Production 将无法工作。
通过本教程,你已经掌握了使用 MLflow 进行实验跟踪和模型注册的核心操作。现在,你可以将这套方法融入日常开发流程,让模型的迭代过程变得透明、可追溯,并轻松地将验证合格的模型交付给下游服务。