模型序列化优化:选择 Protobuf 与 FlatBuffers 加速加载
[field_number + wire_type] [length] [payload]
- **field_number** 和 **wire_type** 被打包成一个 varint(变长整数),通常占用 1-2 字节。
- **length** 同样是 varint,表示后续数据的字节数。
- **payload** 是实际数据,对于数值类型直接以小端序存储,对于嵌套消息则递归编码。
这种设计的优势在于紧凑性和可跳过性——解析器遇到未知字段时可以直接根据 length 跳过,实现向前兼容。
### 在模型序列化中使用 Protobuf
TensorFlow 的 `.pb` 文件和 ONNX 的 `.onnx` 文件本质上都是 Protobuf 序列化结果。以 ONNX 为例:
```python
import onnx
import numpy as np
# 加载 ONNX 模型(底层使用 Protobuf 反序列化)
model = onnx.load("resnet50.onnx")
# 直接访问图结构
graph = model.graph
for node in graph.node[:5]:
print(f"Op: {node.op_type}, Inputs: {node.input}")
对于自定义模型,可以自行定义 proto 文件:
syntax = "proto3";
message TensorProto {
repeated int64 shape = 1;
bytes raw_data = 2; // 原始二进制数据,无需编码
string dtype = 3;
}
message ModelProto {
repeated TensorProto weights = 1;
map<string, string> metadata = 2;
}
编译并使用:
protoc --python_out=. model.proto
import model_pb2
import numpy as np
# 序列化
model_proto = model_pb2.ModelProto()
weight = model_proto.weights.add()
weight.shape.extend([64, 3, 7, 7])
weight.raw_data = np.random.randn(64, 3, 7, 7).astype(np.float32).tobytes()
weight.dtype = "float32"
serialized = model_proto.SerializeToString()
# 反序列化
recovered = model_pb2.ModelProto()
recovered.ParseFromString(serialized)
arr = np.frombuffer(recovered.weights[0].raw_data, dtype=np.float32)
arr = arr.reshape(tuple(recovered.weights[0].shape))
Protobuf 的加载性能特征
Protobuf 的反序列化仍然需要完整解析整个消息。对于大规模模型,这一过程依然耗时,但相比 pickle/JSON 已有显著提升:
| 方案 | 500MB 模型加载耗时 | 内存峰值 |
|---|---|---|
| pickle | ~12s | 2.5x 文件大小 |
| JSON | ~25s | 3x 文件大小 |
| Protobuf | ~4s | 1.8x 文件大小 |
Protobuf 的优势来自于二进制编码的紧凑性和高效的解析器实现,但它仍然需要一次完整的数据复制。
FlatBuffers:零拷贝序列化的极致方案
FlatBuffers 的核心哲学
FlatBuffers 是 Google 为游戏和性能敏感场景设计的序列化库。它的核心卖点是:
无需解析。 数据以特定的二进制布局存储在磁盘上,该布局可以直接映射到内存中,反序列化仅仅是获取一个指向数据的指针。
这意味着使用 FlatBuffers 时,从磁盘读取到内存映射再到数据访问,全程零解析、零拷贝。对于包含巨大权重矩阵的深度学习模型,这一特性尤为关键。
FlatBuffers 的内存布局设计
FlatBuffers 采用一种“从后往前构建”的策略,确保所有偏移量在构建时就写入,读取时无需计算:
[Root offset (4 bytes)] ... [Child data] ... [Parent table] [VTable]
<------------------ 文件末尾 文件起始
每个 table 包含一个 VTable 引用和字段偏移量,读取字段时根据偏移量直接定位到内存地址。VTable 机制同时实现了 schema 演化——新增字段只需在 VTable 末尾追加,旧版读取器可以安全忽略。
在模型加载中应用 FlatBuffers
定义 schema 文件(model.fbs):
table Tensor {
shape: [long];
data: [ubyte]; // 原始浮点数据的字节视图
dtype: string;
}
table Model {
weights: [Tensor];
metadata: [string];
}
root_type Model;
编译并生成代码:
flatc --python model.fbs
使用示例:
import numpy as np
from flatbuffers import Builder
import Model # 由 flatc 生成
# 构建模型文件
builder = Builder(1024 * 1024 * 1024) # 1GB 初始缓冲区
# 准备权重数据
weight_data = np.random.randn(64, 3, 7, 7).astype(np.float32)
data_bytes = weight_data.tobytes()
# FlatBuffers 构建(从后往前)
data_offset = builder.CreateByteVector(data_bytes)
shape = [64, 3, 7, 7]
Model.TensorStartShapeVector(builder, len(shape))
for dim in reversed(shape):
builder.PrependInt64(dim)
shape_offset = builder.EndVector()
dtype_offset = builder.CreateString("float32")
Model.TensorStart(builder)
Model.TensorAddShape(builder, shape_offset)
Model.TensorAddData(builder, data_offset)
Model.TensorAddDtype(builder, dtype_offset)
tensor_offset = Model.TensorEnd(builder)
# 构建根表
Model.ModelStartWeightsVector(builder, 1)
builder.PrependUOffsetTRelative(tensor_offset)
weights_offset = builder.EndVector()
Model.ModelStart(builder)
Model.ModelAddWeights(builder, weights_offset)
model_offset = Model.ModelEnd(builder)
builder.Finish(model_offset)
# 获取序列化字节
serialized = builder.Output()
# 写入文件
with open("model.fb", "wb") as f:
f.write(serialized)
# ---- 零拷贝读取 ----
with open("model.fb", "rb") as f:
buf = f.read()
# 直接从缓冲区访问,无需解析
model = Model.Model.GetRootAs(buf, 0)
tensor = model.Weights(0)
# 直接获取数据的 numpy 视图(零拷贝)
arr = np.frombuffer(
tensor.DataAsNumpy().tobytes(), # 这里实际是 bytes 引用
dtype=np.float32
).reshape(tuple(tensor.ShapeAsNumpy()))
更彻底的零拷贝方案是结合 mmap:
import mmap
with open("model.fb", "r+b") as f:
# 内存映射文件
mm = mmap.mmap(f.fileno(), 0)
model = Model.Model.GetRootAs(mm, 0)
# 所有访问直接从 mmap 区域读取,无需复制到用户空间
tensor = model.Weights(0)
shape = [tensor.Shape(i) for i in range(tensor.ShapeLength())]
FlatBuffers 的加载性能
在零拷贝模式下,FlatBuffers 的加载时间几乎是常数级别,与模型大小无关:
| 方案 | 500MB 模型加载耗时 | 首次访问延迟 |
|---|---|---|
| Protobuf | ~4s | 即时 |
| FlatBuffers (直接读) | ~0.8s | 即时 |
| FlatBuffers (mmap) | ~0.05s | 按需页调入 |
mmap 模式下,0.05 秒仅仅是建立虚拟内存映射的时间。实际数据直到被访问时才由操作系统的缺页中断按需加载,这对大型语言模型的权重加载尤其有价值——如果只使用部分权重(如 LoRA 适配器),则根本不会加载完整模型。
Protobuf 与 FlatBuffers 的权衡对比
选择指南
| 维度 | Protobuf | FlatBuffers |
|---|---|---|
| 加载速度 | 快(需完整解析) | 极快(零解析/零拷贝) |
| 序列化速度 | 快 | 中等(构建需预分配) |
| 文件体积 | 优秀(二进制紧凑编码) | 良好(略大于 Protobuf,因偏移量字段) |
| 内存占用 | 1.5-2x 文件大小 | ~1x(mmap 模式可部分驻留) |
| Schema 演化 | 优秀(字段编号机制) | 优秀(VTable 机制) |
| 生态集成 | 非常广泛(TF、ONNX 原生支持) | 较少(需自行适配) |
| 使用复杂度 | 中等 | 较高(构建顺序有要求) |
| 适合场景 | 通用部署、框架集成 | 冷启动敏感、超大模型、边缘设备 |
性能瓶颈的本质差异
Protobuf 的瓶颈在于反序列化时的内存分配和数据复制。即使使用 Arena 分配器优化内存管理,仍需逐字段解析。对于包含数千个权重张量的模型,解析开销呈线性增长。
FlatBuffers 将数据结构设计与内存布局合二为一,其代价是构建过程较为复杂(需要从叶子节点向根节点反向构建),但这一代价在序列化时一次性支付,换来的是读取端的极致性能。
实际优化案例
场景一:PyTorch 模型的自定义 Protobuf 导出
import torch
import torch.nn as nn
import model_pb2
import numpy as np
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 10)
def to_protobuf(self) -> bytes:
proto = model_pb2.ModelProto()
for name, param in self.state_dict().items():
tensor = proto.weights.add()
tensor.name = name
tensor.shape.extend(param.shape)
tensor.raw_data = param.cpu().numpy().tobytes()
tensor.dtype = str(param.dtype)
return proto.SerializeToString()
@classmethod
def from_protobuf(cls, data: bytes):
proto = model_pb2.ModelProto()
proto.ParseFromString(data)
model = cls()
state_dict = {}
for tensor in proto.weights:
arr = np.frombuffer(tensor.raw_data, dtype=np.dtype(tensor.dtype))
arr = arr.reshape(tuple(tensor.shape))
state_dict[tensor.name] = torch.from_numpy(arr)
model.load_state_dict(state_dict)
return model
场景二:大规模语言模型的 FlatBuffers 分片加载
对于 70B 参数的大语言模型(约 140GB),完整加载到内存需要多张 GPU。使用 FlatBuffers + mmap,可以实现按需分片加载:
import mmap
import numpy as np
import torch
class FlatBufferLazyModel:
"""延迟加载的 FlatBuffers 模型封装"""
def __init__(self, path: str):
self.file = open(path, "r+b")
self.mm = mmap.mmap(self.file.fileno(), 0)
self.model = Model.Model.GetRootAs(self.mm, 0)
self._cache = {}
def get_weight(self, index: int) -> torch.Tensor:
if index not in self._cache:
tensor = self.model.Weights(index)
raw = tensor.DataAsNumpy()
# mmap 区域的 bytes,转换为 tensor 时零拷贝
arr = np.frombuffer(raw, dtype=np.float32).copy()
arr = arr.reshape([tensor.Shape(i) for i in range(tensor.ShapeLength())])
self._cache[index] = torch.from_numpy(arr)
return self._cache[index]
def evict(self, index: int):
"""释放缓存,允许 OS 回收对应内存页"""
self._cache.pop(index, None)