RAFT:全对场变换的精确光流估计
光流估计:RAFT(全对场变换的精确光流估计)完全指南
光流估计是计算机视觉中的基础任务,目标是为视频中每个像素找到其在下一帧中的运动矢量。RAFT (Recurrent All-Pairs Field Transforms) 在 2020 年由普林斯顿大学团队提出,以其简洁的结构和顶尖的性能迅速成为光流领域的新基准。本教程将带你从零理解 RAFT 的核心思想、网络架构以及如何在实际中使用它。
目录
- 什么是光流?
- RAFT 的诞生背景
- 核心思想:全对场变换与循环更新
- RAFT 网络架构详解
- RAFT 的优势与特性
- 实战:如何使用 RAFT 进行光流估计
- 进阶:在自定义数据上微调 RAFT
- 常见问题与调优建议
- 总结
什么是光流?
光流(Optical Flow)是空间运动物体在观察成像平面上像素运动的瞬时速度。它为每个像素点分配一个二维向量 ((u,v)),描述该点在两帧之间的位移方向及大小。光流广泛应用于视频插帧、动作识别、目标跟踪、自动驾驶和运动补偿等领域。
RAFT 的诞生背景
在 RAFT 出现之前,主流光流方法主要分为两类:
- 基于能量的经典方法:如 Horn–Schunck、Farneback,速度快但精度有限。
- 基于深度学习的回归方法:如 FlowNet、PWC-Net,通过端到端网络直接回归光流场,精度提高但参数量庞大、泛化性常受限于训练数据。
这些方法往往需要在速度和精度间取舍,也容易在遮挡区域或大运动场景失效。RAFT 提出了一种全新范式:在高分辨率特征上构建全对场相关代价,并利用一个轻量级循环单元迭代更新流场。这种设计兼具高精度、强泛化和高效推理的特点。
核心思想:全对场变换与循环更新
RAFT 的全称“Recurrent All-Pairs Field Transforms”揭示了其三大核心组件:
- 全对场相关(All-Pairs Correlation):计算两帧间所有像素特征向量的相似度,而不是局部邻域,从而能捕捉远距离运动关系。
- 循环结构(Recurrent):使用一个基于 GRU 的更新操作器,从零初始流场出发,反复查阅相关代价并逐步优化流场,模拟传统优化方法的迭代过程。
- 场变换(Field Transforms):每次迭代时,利用当前流场对相关代价进行查找,并将查找结果输送给更新器,这种基于流的查找机制使得多尺度信息能有效融合。
这种设计使网络无需大量参数就能拟合复杂光流模式,且训练与推理过程更稳定。
RAFT 网络架构详解
RAFT 的推理流程如下:
- 输入:相邻两帧 (I_1) 和 (I_2)。
- 特征提取:通过共享权重的卷积编码器分别提取 1/8 分辨率的稠密特征图。
- 构建相关金字塔:计算两个特征图所有位置间的点积,得到 4D 相关代价体,并下采样为多尺度金字塔。
- 初始化光流场:全零。
- 循环更新:将当前流场、上下文特征、相关代价查找结果输入 GRU 更新块,预测残差流,迭代多次(例如12次)。
- 上采样:将 1/8 光流通过一个可学习的上采样网络还原到原始分辨率。
下面逐一剖析每个模块。
特征提取器
特征编码器采用标准 ResNet 类结构,但将 stride 设为 8,输出通道数 (D=256)。两个输入帧共享该编码器,得到特征图 (F_1, F_2 \in \mathbb{R}^{H \times W \times D})。
为了后续循环更新,还需要提取一个仅依赖 (I_1) 的上下文特征,通常使用与上述结构类似但稍深一点的网络,输出 (H \times W \times C_{\text{context}}),该特征将在每次迭代中与其它信息拼接送入 GRU。
相关金字塔
4D 相关代价体:对 (F_1) 中的每个位置 (x) 和 (F_2) 中的每个位置 (y),计算归一化内积: [ \mathbf{C}(x, y) = \frac{ \langle F_1(x), F_2(y) \rangle }{ \sqrt{D} } ] 得到尺寸为 (H \times W \times H \times W) 的张量。直接存储和查找在计算上不可行,因此 RAFT 通过池化构建一个多级相关金字塔。
金字塔构建:对最后两个维度(即 (F_2) 的空间维度)反复进行平均池化(例如 kernel=2,stride=2),生成多个尺度:(H \times W \times H \times W)、(H \times W \times H/2 \times W/2)、(H \times W \times H/4 \times W/4) … 通常4层。金字塔越顶层分辨率越低,对应更大的感受野,适合大运动;底层精细,适合精确调整。
循环更新操作器
这是 RAFT 的核心创新。它从初始光流场 (\mathbf{f}_0 = \mathbf{0}) 开始,在每一步 (k) 执行:
- 相关查找:根据当前流场 (\mathbf{f}_k),从 (F_1) 中的位置 x 映射到 (F_2) 中的 (x + \mathbf{f}_k(x))。在相关金字塔的每一层,我们以映射点为中心,采集局部窗口(如半径=4)内的相关值,得到金字塔相关特征。将所有层的查找结果拼接,形成输入给 GRU 的相关特征向量。
- 更新操作器:GRU 的输入由三部分拼接而成:上一步隐藏状态导出的流场特征、金字塔相关特征、以及预先提取的上下文特征。GRU 输出用于估计残差光流 (\Delta \mathbf{f})。
- 流场更新:(\mathbf{f}_{k+1} = \mathbf{f}_k + \Delta \mathbf{f})。
这个迭代过程可反复进行(默认12次),最终输出精化的 1/8 分辨率光流。可视作一个学习到的迭代优化算法,每次迭代都基于之前估计的流场从相关金字塔中汲取新证据。
上采样:由于光流在原始图像分辨率下通常具有平滑且边界清晰的特性,RAFT 采用一个凸组合上采样器。它预测每个低分辨率像素映射到高分辨率的掩码权重,对邻域格点进行加权求和,从而恢复全分辨率光流。
RAFT 的优势与特性
- 极高精度:在 Sintel、KITTI 等基准上长期占据首位,甚至超越许多有监督方法。
- 强泛化能力:训练时仅用合成数据(如 FlyingChairs + FlyingThings),即可在真实视频上表现良好,无需在目标域微调。
- 高效率:虽然含有循环结构,但迭代次数少(12次),单次前传推理约 7 fps(2010年代GPU),远快于许多稠密匹配方法。后续加速版还可实时。
- 结构纯净:没有复杂的解码器或代价卷形状变换,代码易读易复现。
- 灵活输出:可直接嵌入到视频理解下游任务,模型可端到端训练。
实战:如何使用 RAFT 进行光流估计
下面以官方 PyTorch 实现为例,展示从环境搭建到结果可视化的完整流程。
环境准备
# 创建虚拟环境(可选)
conda create -n raft python=3.8
conda activate raft
# 安装依赖
pip install torch torchvision opencv-python matplotlib imageio
git clone https://github.com/princeton-vl/RAFT.git
cd RAFT
如果只想推理,不需要安装额外训练库。主要依赖是 PyTorch (>=1.6) 和 CUDA 工具包。
加载预训练模型
RAFT 提供多个官方权重,如 raft-things.pth(在 FlyingThings 上训练,通用性好)或 raft-sintel.pth(在 Sintel 上微调)。下载并放置在 models/ 目录下。
import torch
from RAFT.core.raft import RAFT
from RAFT.core.utils.utils import InputPadder
model = RAFT(args) # args 可默认构造,也可自定义
model = torch.nn.DataParallel(model)
checkpoint = torch.load('models/raft-things.pth')
model.load_state_dict(checkpoint, strict=False)
model.cuda()
model.eval()
注意参数 args.small=False 使用大模型,如需更低延迟可使用 --small 加载小模型。
推理单对图片
准备两张连续帧图片 frame1.png 和 frame2.png,将它们转换为 [0,255] 的 uint8 数组或归一化 Tensor。
import cv2, numpy as np
from torchvision import transforms
def load_image(path):
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img).permute(2,0,1).float()[None] # 1,3,H,W
return img
image1 = load_image('frame1.png')
image2 = load_image('frame2.png')
# 保证尺寸能被8整除
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
with torch.no_grad():
flow_low, flow_up = model(image1.cuda(), image2.cuda(), iters=20, test_mode=True)
# flow_up 即为最终全分辨率光流 [1,2,H,W]
iters 控制循环迭代次数,默认24或12均可,越高精度稍增但速度下降。test_mode=True 关闭梯度。
可视化光流结果
可使用 flow_vis 工具将光流转换为色轮图(H代表方向,V代表大小):
from RAFT.core.utils.flow_viz import flow_to_image
import matplotlib.pyplot as plt
# 转换为 numpy
flow_np = flow_up[0].permute(1,2,0).cpu().numpy()
flow_color = flow_to_image(flow_np, convert_to_bgr=False) # RGB
plt.imshow(flow_color)
plt.title('RAFT Optical Flow')
plt.axis('off')
plt.show()
你也可以保存为图片或视频。flow_viz 默认使用 HSV 色盘:不同颜色表示运动方向(如下图例),颜色饱和度表示运动大小。

(图示:色轮表示光流方向,白/黑为静止)
进阶:在自定义数据上微调 RAFT
如果你需要在特定场景(如医学影像、红外视频)优化性能,可对 RAFT 进行微调。基本步骤:
- 准备数据集:格式为包含成对图片和
.flo光流文件的文件夹结构(如 FlyingChairs 风格)。 - 配置训练参数:修改
train.py中的参数,指定预训练权重路径等。 - 启动训练:
建议使用较小的学习率(如 1e-5)和少步数,防止过拟合。python train.py --name my_finetune \ --stage chairs \ --restore_ckpt models/raft-things.pth \ --num_steps 50000 \ --batch_size 6 \ --gpus 0 - 验证:可使用
evaluate.py在验证集上测试 EPE 等指标。
注意:RAFT 对训练数据中的大运动范围适应性很好,小规模自定义数据集只需少量训练即可收敛。
常见问题与调优建议
- 输入图像尺寸限制:RAFT 要求尺寸能被 8 整除,若不成则用
InputPadder进行补边(补零或复制边缘),并记录 pad 值以裁剪回原尺寸。 - 内存溢出:全对场相关代价体占显存较大,可通过降低特征分辨率或使用
--small模型缓解。推论时 batch_size 通常为1。 - 遮挡区域光流不准:这是所有方法的通病。RAFT 能通过迭代推断遮挡的一致性,但仍会输出明显噪声。可结合前向后向一致性检查剔除错误估计。
- 实时性需求:原始 RAFT 约 10 fps。可以使用加速版(如 RAFT-Stereo 的优化技巧、TensorRT、ONNX 导出)达到实时。也可减少迭代次数至 4 ~ 8 次,精度损失较小。
- 与 PWC-Net 比较:RAFT 精度显著更高,对剧烈运动更鲁棒,但计算量略大。若硬件受限且任务简单,仍可考虑 PWC-Net。
- 输出光流比例:务必确认预处理中图片的缩放因子。官方实现中图片取值范围保持 [0,255],光流值以像素为单位。若输入被归一化或缩放,输出流需相应反变换。
总结
RAFT 重新定义了光流估计的范式,用全对场相关、多尺度金字塔查找和轻量循环单元实现了高精度与强泛化的统一。无论你是想将其作为视频分析流水线的一环,还是进行学术研究,掌握 RAFT 都能带来显著提升。本教程从原理到代码完整覆盖,现在你就可以动手尝试并应用到自己项目中。
进一步阅读:
- 原论文:RAFT: Recurrent All-Pairs Field Transforms for Optical Flow
- GitHub 仓库:princeton-vl/RAFT
- 相关改进:GMA(Global Motion Aggregation)、FlowFormer、GMFlow 等 Transformer 架构。
开始探索光流世界吧!