Donut:无 OCR 管线直接端到端文档解析

FreeGuideOnline 最新 2026-06-23

bash conda create -n donut python=3.9 -y conda activate donut


安装核心依赖:

```bash
pip install transformers sentencepiece pytorch-lightning
pip install donut-python   # 官方轻量封装包

如需从源码安装以获得最新特性:

git clone https://github.com/clovaai/donut.git
cd donut
pip install -e .

验证安装成功:

from donut import DonutModel
print("Donut 安装成功")

数据准备

Donut 任务通过 JSON 结构定义,训练数据由图像文件与对应的标注 JSON 文件组成。典型目录结构:

dataset/
├── train/
│   ├── metadata.jsonl
│   └── image1.png, image2.jpg ...
└── validation/
    ├── metadata.jsonl
    └── ...

metadata.jsonl 每一行是一个训练样本:

{"file_name": "image1.png", "ground_truth": "{\"total\": \"$9.45\", \"date\": \"2021-09-29\", \"items\": [...]}", "task_name": "receipt"}

要点:

  • ground_truth 必须是严格的合法 JSON 字符串,不要包含多余空格或换行。
  • task_name 用于标识任务,可自由定义,如 "invoice", "form" 等。
  • 图像分辨率建议保持 1280 像素的较长边,避免过度压缩导致文字模糊。

转换已有标注
若已有 OCR 结果和字段,可直接将键值对组装成 JSON 字符串作为目标。简单任务甚至可以完全用模板生成合成数据。

模型训练实战

使用 train.py 脚本或 PyTorch Lightning 进行微调。以下为命令行微调示例:

python train.py \
  --task receipt \                       # 任务名,与 jsonl 中一致
  --data_dir ./dataset \                 # 数据集根目录
  --backbone donut-base \                # 预训练模型 (可选 donut-proto, donut-base)
  --batch_size 1 \                       # 显存较小时保持 1,模型较吃显存
  --lr 3e-5 \
  --max_epochs 30 \
  --num_workers 4 \
  --val_check_interval 0.2

关键参数说明

  • backbonedonut-base 适合通用场景,donut-proto 更轻量。首次运行会自动下载预训练权重。
  • batch_size:单张 GPU 通常只能设为 1,因为 Decoder 序列长度变化大,可用梯度累积弥补。
  • max_epochs:小数据集(数百张)建议 30-50 轮,配合早停防止过拟合。
  • val_check_interval:控制验证频率,避免计算瓶颈。

训练过程会自动打印验证集的 token 准确率及 JSON 有效性指标。训练完成后模型保存在 result/ 目录。

推理与结构化提取

加载训练好的模型并执行单一图片推理:

from donut import DonutModel
from PIL import Image
import torch

model = DonutModel.from_pretrained("result/trained_model")
model.eval()

image = Image.open("test_receipt.jpg").convert("RGB")
output = model.inference(image=image, prompt="<s_receipt>")  # 任务提示符
print(output)

<s_receipt> 是任务特定的起始 token,需与训练时的任务名一致。模型输出为纯文本 JSON 字符串,可直接用 json.loads() 解析。

若需批量处理并考虑性能:

predictions = model.inference_batch(
    images=[img1, img2],
    prompts=["<s_receipt>", "<s_receipt>"],
    return_attentions=False
)

处理输出错误
模型偶尔会生成格式错误的 JSON(如缺引号、多余逗号)。可在后处理中尝试部分修复,或设置更高温度进行多次采样投票。

自定义任务与数据集构建

Donut 的灵活性允许你定义任何结构化提取任务,只需设计好 JSON schema。例如护照信息提取:

{"document_type": "passport", "surname": "DOE", "given_names": "JOHN", "nationality": "USA"}

构建相应数据集后,训练时指定 --task passport,推理提示符用 <s_passport>

合成数据生成
官方提供 SynthDoG 工具,可以从给定的模板(HTML 或网页布局)随机生成文档图像及其 JSON 标注。这是快速获取训练数据的有效方式:

pip install synthdog
synthdog -c config.yaml -o synth_data