MLIR:面向多级抽象的编译器中间表示
MLIR:面向多级抽象的编译器中间表示
MLIR(Multi-Level Intermediate Representation)是 LLVM 生态中的新一代编译器基础设施,旨在统一机器学习框架、硬件后端以及领域专用编译器中的不同抽象层次。与传统的扁平化 IR 不同,MLIR 允许在同一程序中同时表示高层操作(如张量运算)和低层操作(如向量化指令),并通过渐进式降级将高层语义一步步转换为可执行代码。
本教程面向对编译器或机器学习框架感兴趣的初学者,从零开始介绍 MLIR 的核心概念,并通过具体示例帮助你理解如何使用 MLIR 表示和变换程序。
1. 为什么需要 MLIR?
现代编译栈面临的挑战日益复杂:
- 机器学习模型需要跨越不同框架(TensorFlow、PyTorch)、不同硬件(CPU、GPU、TPU)和不同精度(FP32、INT8)。
- 传统 IR(如 LLVM IR)单一抽象层级,难以表达高层领域语义(如卷积、矩阵乘法)和硬件特定优化。
- 各类领域专用编译器各自为政,重复造轮子,缺乏共享基础设施。
MLIR 通过以下设计应对这些挑战:
- 多级抽象:在同一 IR 中自然表达从张量图到标量运算的所有层级。
- 方言系统:允许定义领域专用操作和类型,并组合使用。
- 可扩展遍:编译过程以 pass 形式组织,可以混合不同方言的变换。
- 渐进降级:从高级方言逐步转化为低级方言,最终生成 LLVM IR 或目标代码。
2. MLIR 核心概念速览
MLIR 程序的基本组成单元是操作(Operation)。一个 MLIR 模块由一系列操作嵌套在**块(Block)和区域(Region)中构成,操作使用特定方言(Dialect)**定义的类型和特性。
2.1 操作、块与区域
- 操作:MLIR 中最小的语义单元。类似于指令,但可以表示任意的抽象层级。每个操作有一个操作名(如
arith.add),若干操作数,若干结果,以及附属属性(attributes)。 - 块:操作的有序列表,末尾必须有一个终止操作(如
func.return)。块可以拥有参数。 - 区域:一个操作可以包含多个块。区域通常用于表示控制流结构(如循环、条件分支)或函数体。
下面的简单 MLIR 程序片段展示了这些结构:
func.func @simple_add(%arg0: i32, %arg1: i32) -> i32 {
%0 = arith.addi %arg0, %arg1 : i32
func.return %0 : i32
}
func.func是一个操作,表示一个函数定义。它包含一个区域,即函数体。- 函数体是一个块,内含两个操作:
arith.addi和func.return。 arith.addi是来自arith方言的整数加法操作,%arg0、%arg1是块参数。
2.2 方言
方言是 MLIR 的核心扩展机制。每个方言定义了一组操作、类型和属性,用于表达特定领域的语义。常用方言包括:
builtin:内建方言,提供模块、函数、基本类型等。func:函数定义、调用的抽象,独立于具体调用约定。arith:整数和浮点算术运算、比较、扩展/截断等。scf:结构化控制流(for、while、if 等)。affine:面向多面体编译的循环与内存操作。memref:多维数组抽象,表示显式内存。linalg:线性代数泛化操作,用于机器学习运算的降级。tensor:张量抽象,无内存副作用。vector:向量掩码操作与虚拟向量寄存器。llvm:与 LLVM IR 直接映射的方言,用于最终降级。gpu、nvvm、rocdl:GPU 相关方言。
你可以组合使用来自不同方言的操作,编译器 pass 能够将一种方言逐步转换为另一种方言。
3. MLIR 的类型系统
MLIR 拥有丰富的类型系统,类型也由方言定义。常见的类型包括:
- 内置类型:
i1,i8,i32,f32,f64,index,none,string等。 tensor<...>:张量类型,如tensor<2x4xf32>。memref<...>:内存引用类型,如memref<2x4xf32>。与tensor不同,memref涉及具体内存空间,可以被反复读写。vector<...>:虚拟向量类型,如vector<4xf32>。!dialect.type:方言自定义类型,如!gpu.symbol。
类型检查在 MLIR 中是必选的,操作数类型必须与操作期望的约束匹配。
4. 属性与装饰器
操作除了操作数和类型,还能附带属性,用于存储编译期常量信息,如常量值、内存布局、边界等。属性没有 SSA 使用-定义链,不可变。
%c = arith.constant 42 : i32 // 属性值为 42
%shape = memref.alloc(%size) {alignment = 64} : memref<?xf32>
上面的 alignment = 64 是一个命名属性,用于指定内存对齐。
5. MLIR 的 Pass 与编译流程
MLIR 编译器将输入程序经过一系列 pass 变换为输出表示。Pass 是操作在特定粒度上的变换,可以是分析、规范变换、降级、优化等。MLIR 提供多种 pass 基础设施:
- Operation pass:在某个操作上运行,不改变嵌套结构。
- Analysis:提供查询信息。
- Conversion pass:将一组方言的操作转换为另一组方言的操作(方言降级)。
典型的编译管线可能是:
tensor/xla/linalg → linalg → affine/scf → memref/vector → llvm
每一步都是一个或多个 pass 的组合。
你可以通过 mlir-opt 工具指定 pass 管线,观察中间表示的变化。
6. 实践:你的第一个 MLIR 程序
下面通过一个具体例子,展示如何定义一个 MLIR 模块,并使用 mlir-opt 进行变换。
6.1 环境准备
安装 LLVM/MLIR 工具链(通常通过编译 LLVM 源码获得 mlir-opt、mlir-translate 等)。你可以使用官方预构建版本或自行编译。
6.2 编写 MLIR 模块
创建文件 example.mlir,内容如下:
module {
func.func @matmul(%A: tensor<4x8xf32>, %B: tensor<8x4xf32>) -> tensor<4x4xf32> {
%c0 = arith.constant 0.0 : f32
%res = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x4xf32>)
outs(%c0 : tensor<4x4xf32>) -> tensor<4x4xf32>
func.return %res : tensor<4x4xf32>
}
}
上述代码定义了一个模块,包含一个函数 matmul,它使用 linalg.matmul 操作表示矩阵乘法。这是一个高级抽象,尚未涉及任何循环或内存。
6.3 使用 Pass 降级
mlir-opt example.mlir --linalg-generalize-named-ops --convert-linalg-to-affine-loops --affine-loop-unroll --convert-scf-to-cf --convert-cf-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --reconcile-unrealized-casts
--linalg-generalize-named-ops:将命名的 linalg 操作(如matmul)泛化为通用的 linalg 泛化形式。--convert-linalg-to-affine-loops:将 linalg 操作降级为仿射循环嵌套。- 随后的 pass 逐步将控制流、函数、算术运算等降级到 LLVM 方言。
每一步运行中间结果,可以只运行部分 pass,并加上 --mlir-print-ir-after-all 观察每个 pass 后的 IR。
6.4 生成 LLVM IR 或目标代码
最终通过 mlir-translate 将 LLVM 方言翻译为 LLVM IR:
mlir-opt example.mlir <pass管线> | mlir-translate --mlir-to-llvmir
然后你可以使用 llc 和 clang 将其编译为可执行文件,或使用 JIT 执行引擎。
7. 自定义方言与操作
MLIR 的强大来源于可扩展性。你可以为你的领域定义新的方言和操作。定义方言通常使用表驱动定义(ODS)和声明式重写规则(DRR),通过 C++ 与 TableGen 描述。
一个简化的自定义方言定义示例(TableGen):
def MyDialect : Dialect {
let name = "my_dialect";
let cppNamespace = "::mlir::my_dialect";
}
def MyOp : Op<MyDialect, "my_op"> {
let arguments = (ins I32:$input);
let results = (outs I32:$output);
let assemblyFormat = "$input attr-dict `:` type($output)";
}
编译后,你可以在 MLIR 中直接使用 %o = my_dialect.my_op %i : i32。配合相应的 C++ 实现和变换 pass,你就可以构建完整的领域编译器。
8. 学习路径建议
- 熟悉基本语法:阅读 MLIR 语言参考,理解 SSA、操作、块、区域、方言。
- 动手实验:使用
mlir-opt对简单的arith、func、scf进行变换。 - 理解方言降级:跟踪一个张量运算从
linalg到LLVM的全过程。 - 阅读源码:研究 MLIR 上游教程示例(位于
mlir/examples/toy),逐步实现自己的方言。 - 深入特定领域:根据兴趣深入
affine、linalg、sparse_tensor等方言的设计与优化模式。
MLIR 生态系统正在快速发展,掌握其多级抽象与渐进降级的思维方式,将使你能够在编译器、高性能计算、AI 框架优化等领域获得强大的设计能力。