TFX:TensorFlow 生产级机器学习平台

FreeGuideOnline 最新 2026-06-20

bash pip install tfx


若需要完整的示例项目,可以克隆官方仓库:

```bash
git clone https://github.com/tensorflow/tfx
cd tfx
pip install -e .

建议在虚拟环境中操作。TFX 的核心依赖包括 TensorFlow、Apache Beam、ML Metadata、TensorFlow Data Validation(TFDV)、TensorFlow Transform(TFT)等,直接安装 tfx 包会一并拉取。

构建你的第一条 TFX 流水线

我们将用一个经典的芝加哥出租车小费预测数据集,演示如何搭建本地流水线。这里使用 LocalDagRunner,在一个进程中按序执行所有组件,适合开发和调试。

1. 准备数据和项目结构

下载 CSV 数据(可来自 TFX 示例数据),创建一个 tfx_pipeline 目录。

tfx_pipeline/
├── data/
│   └── taxi.csv
└── pipeline.py

2. 编写组件定义

pipeline.py 中,逐一定义 TFX 组件。首先导入必要库:

import os
from tfx.components import (ExampleGen, StatisticsGen, SchemaGen, 
                            ExampleValidator, Transform, Trainer, 
                            Evaluator, Pusher)
from tfx.orchestration.experimental.local.local_dag_runner import LocalDagRunner
from tfx.proto import example_gen_pb2, trainer_pb2, pusher_pb2
from tfx.types import Channel, standard_artifacts

ExampleGen:输入 CSV 路径。

example_gen = ExampleGen(
    input=example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='data/taxi*.csv'),
    ])
)

StatisticsGen 和 SchemaGen:直接使用默认配置,链接前一个组件的输出。

statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)

Transform:需要编写一个 preprocessing_fn 函数,定义特征处理。

_transform_module_file = 'taxi_transform.py'

def preprocessing_fn(inputs):
    outputs = {}
    # 对数值特征进行 Z-score 标准化
    for key in ['trip_miles', 'trip_seconds', 'fare']:
        outputs[key + '_zscore'] = tft.scale_to_z_score(inputs[key])
    # 对类别特征生成词表并转换为索引
    for key in ['payment_type', 'company']:
        outputs[key + '_idx'] = tft.compute_and_apply_vocabulary(inputs[key])
    # 目标标签
    outputs['tips'] = inputs['tips']
    return outputs

将上述代码保存为 taxi_transform.py,并在 Transform 组件中引用:

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=_transform_module_file
)

Trainer:需要编写一个 Keras 训练模块 taxi_trainer.py,按约定提供 run_fn

def run_fn(fn_args):
    train_data = fn_args.train_files
    eval_data = fn_args.eval_files
    # 通过 Transform 后的数据读取
    dataset = tf.data.TFRecordDataset(train_data, compression_type="GZIP")
    # 构建模型、编译、训练...
    model.save(fn_args.serving_model_dir, save_format='tf')

Trainer 组件:

trainer = Trainer(
    module_file='taxi_trainer.py',
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500)
)

Evaluator:需要编写评估配置,可简单使用 TFX 提供的 EvalConfig 构建。

eval_config = ... # 定义指标切片等
evaluator = Evaluator(
    examples=transform.outputs['transformed_examples'],
    model=trainer.outputs['model'],
    baseline_model=None,
    eval_config=eval_config
)

Pusher:推送到指定目录。

pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory='serving_model/taxi_model')
    )
)

3. 组装流水线并运行

将组件列表传入 Pipeline,然后用 LocalDagRunner 运行。

from tfx.orchestration import pipeline

components = [
    example_gen,
    statistics_gen,
    schema_gen,
    example_validator,
    transform,
    trainer,
    evaluator,
    pusher,
]

pipeline_root = 'pipeline_output'
pipeline_name = 'taxi_tip_pipeline'

taxi_pipeline = pipeline.Pipeline(
    pipeline_name=pipeline_name,
    pipeline_root=pipeline_root,
    components=components,
    enable_cache=True,
    metadata_connection_config=None  # 使用 SQLite 作为元数据存储
)

runner = LocalDagRunner()
runner.run(taxi_pipeline)

执行 python pipeline.py,TFX 会依次运行每个组件,并在控制台输出执行日志。成功运行后,你将在 pipeline_output 目录下看到各组件产生的工件,在 serving_model/ 下找到可部署的模型。

深入理解关键库

TensorFlow Data Validation (TFDV)

TFDV 是 StatisticsGen、SchemaGen 和 ExampleValidator 背后的引擎。你可以单独使用它进行数据探索:

import tensorflow_data_validation as tfdv
stats = tfdv.generate_statistics_from_csv('data.csv')
tfdv.visualize_statistics(stats)