TFX:TensorFlow 生产级机器学习平台
bash pip install tfx
若需要完整的示例项目,可以克隆官方仓库:
```bash
git clone https://github.com/tensorflow/tfx
cd tfx
pip install -e .
建议在虚拟环境中操作。TFX 的核心依赖包括 TensorFlow、Apache Beam、ML Metadata、TensorFlow Data Validation(TFDV)、TensorFlow Transform(TFT)等,直接安装 tfx 包会一并拉取。
构建你的第一条 TFX 流水线
我们将用一个经典的芝加哥出租车小费预测数据集,演示如何搭建本地流水线。这里使用 LocalDagRunner,在一个进程中按序执行所有组件,适合开发和调试。
1. 准备数据和项目结构
下载 CSV 数据(可来自 TFX 示例数据),创建一个 tfx_pipeline 目录。
tfx_pipeline/
├── data/
│ └── taxi.csv
└── pipeline.py
2. 编写组件定义
在 pipeline.py 中,逐一定义 TFX 组件。首先导入必要库:
import os
from tfx.components import (ExampleGen, StatisticsGen, SchemaGen,
ExampleValidator, Transform, Trainer,
Evaluator, Pusher)
from tfx.orchestration.experimental.local.local_dag_runner import LocalDagRunner
from tfx.proto import example_gen_pb2, trainer_pb2, pusher_pb2
from tfx.types import Channel, standard_artifacts
ExampleGen:输入 CSV 路径。
example_gen = ExampleGen(
input=example_gen_pb2.Input(splits=[
example_gen_pb2.Input.Split(name='train', pattern='data/taxi*.csv'),
])
)
StatisticsGen 和 SchemaGen:直接使用默认配置,链接前一个组件的输出。
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])
example_validator = ExampleValidator(
statistics=statistics_gen.outputs['statistics'],
schema=schema_gen.outputs['schema']
)
Transform:需要编写一个 preprocessing_fn 函数,定义特征处理。
_transform_module_file = 'taxi_transform.py'
def preprocessing_fn(inputs):
outputs = {}
# 对数值特征进行 Z-score 标准化
for key in ['trip_miles', 'trip_seconds', 'fare']:
outputs[key + '_zscore'] = tft.scale_to_z_score(inputs[key])
# 对类别特征生成词表并转换为索引
for key in ['payment_type', 'company']:
outputs[key + '_idx'] = tft.compute_and_apply_vocabulary(inputs[key])
# 目标标签
outputs['tips'] = inputs['tips']
return outputs
将上述代码保存为 taxi_transform.py,并在 Transform 组件中引用:
transform = Transform(
examples=example_gen.outputs['examples'],
schema=schema_gen.outputs['schema'],
module_file=_transform_module_file
)
Trainer:需要编写一个 Keras 训练模块 taxi_trainer.py,按约定提供 run_fn。
def run_fn(fn_args):
train_data = fn_args.train_files
eval_data = fn_args.eval_files
# 通过 Transform 后的数据读取
dataset = tf.data.TFRecordDataset(train_data, compression_type="GZIP")
# 构建模型、编译、训练...
model.save(fn_args.serving_model_dir, save_format='tf')
Trainer 组件:
trainer = Trainer(
module_file='taxi_trainer.py',
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
schema=schema_gen.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=1000),
eval_args=trainer_pb2.EvalArgs(num_steps=500)
)
Evaluator:需要编写评估配置,可简单使用 TFX 提供的 EvalConfig 构建。
eval_config = ... # 定义指标切片等
evaluator = Evaluator(
examples=transform.outputs['transformed_examples'],
model=trainer.outputs['model'],
baseline_model=None,
eval_config=eval_config
)
Pusher:推送到指定目录。
pusher = Pusher(
model=trainer.outputs['model'],
model_blessing=evaluator.outputs['blessing'],
push_destination=pusher_pb2.PushDestination(
filesystem=pusher_pb2.PushDestination.Filesystem(
base_directory='serving_model/taxi_model')
)
)
3. 组装流水线并运行
将组件列表传入 Pipeline,然后用 LocalDagRunner 运行。
from tfx.orchestration import pipeline
components = [
example_gen,
statistics_gen,
schema_gen,
example_validator,
transform,
trainer,
evaluator,
pusher,
]
pipeline_root = 'pipeline_output'
pipeline_name = 'taxi_tip_pipeline'
taxi_pipeline = pipeline.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=components,
enable_cache=True,
metadata_connection_config=None # 使用 SQLite 作为元数据存储
)
runner = LocalDagRunner()
runner.run(taxi_pipeline)
执行 python pipeline.py,TFX 会依次运行每个组件,并在控制台输出执行日志。成功运行后,你将在 pipeline_output 目录下看到各组件产生的工件,在 serving_model/ 下找到可部署的模型。
深入理解关键库
TensorFlow Data Validation (TFDV)
TFDV 是 StatisticsGen、SchemaGen 和 ExampleValidator 背后的引擎。你可以单独使用它进行数据探索:
import tensorflow_data_validation as tfdv
stats = tfdv.generate_statistics_from_csv('data.csv')
tfdv.visualize_statistics(stats)