RAFT:全对场变换的精确光流估计

FreeGuideOnline 最新 2026-06-20

光流估计:RAFT(全对场变换的精确光流估计)完全指南

光流估计是计算机视觉中的基础任务,目标是为视频中每个像素找到其在下一帧中的运动矢量。RAFT (Recurrent All-Pairs Field Transforms) 在 2020 年由普林斯顿大学团队提出,以其简洁的结构和顶尖的性能迅速成为光流领域的新基准。本教程将带你从零理解 RAFT 的核心思想、网络架构以及如何在实际中使用它。

目录

  1. 什么是光流?
  2. RAFT 的诞生背景
  3. 核心思想:全对场变换与循环更新
  4. RAFT 网络架构详解
  5. RAFT 的优势与特性
  6. 实战:如何使用 RAFT 进行光流估计
  7. 进阶:在自定义数据上微调 RAFT
  8. 常见问题与调优建议
  9. 总结

什么是光流?

光流(Optical Flow)是空间运动物体在观察成像平面上像素运动的瞬时速度。它为每个像素点分配一个二维向量 ((u,v)),描述该点在两帧之间的位移方向及大小。光流广泛应用于视频插帧、动作识别、目标跟踪、自动驾驶和运动补偿等领域。

RAFT 的诞生背景

在 RAFT 出现之前,主流光流方法主要分为两类:

  • 基于能量的经典方法:如 Horn–Schunck、Farneback,速度快但精度有限。
  • 基于深度学习的回归方法:如 FlowNet、PWC-Net,通过端到端网络直接回归光流场,精度提高但参数量庞大、泛化性常受限于训练数据。

这些方法往往需要在速度和精度间取舍,也容易在遮挡区域或大运动场景失效。RAFT 提出了一种全新范式:在高分辨率特征上构建全对场相关代价,并利用一个轻量级循环单元迭代更新流场。这种设计兼具高精度、强泛化和高效推理的特点。

核心思想:全对场变换与循环更新

RAFT 的全称“Recurrent All-Pairs Field Transforms”揭示了其三大核心组件:

  1. 全对场相关(All-Pairs Correlation):计算两帧间所有像素特征向量的相似度,而不是局部邻域,从而能捕捉远距离运动关系。
  2. 循环结构(Recurrent):使用一个基于 GRU 的更新操作器,从零初始流场出发,反复查阅相关代价并逐步优化流场,模拟传统优化方法的迭代过程。
  3. 场变换(Field Transforms):每次迭代时,利用当前流场对相关代价进行查找,并将查找结果输送给更新器,这种基于流的查找机制使得多尺度信息能有效融合。

这种设计使网络无需大量参数就能拟合复杂光流模式,且训练与推理过程更稳定。

RAFT 网络架构详解

RAFT 的推理流程如下:

  1. 输入:相邻两帧 (I_1) 和 (I_2)。
  2. 特征提取:通过共享权重的卷积编码器分别提取 1/8 分辨率的稠密特征图。
  3. 构建相关金字塔:计算两个特征图所有位置间的点积,得到 4D 相关代价体,并下采样为多尺度金字塔。
  4. 初始化光流场:全零。
  5. 循环更新:将当前流场、上下文特征、相关代价查找结果输入 GRU 更新块,预测残差流,迭代多次(例如12次)。
  6. 上采样:将 1/8 光流通过一个可学习的上采样网络还原到原始分辨率。

下面逐一剖析每个模块。

特征提取器

特征编码器采用标准 ResNet 类结构,但将 stride 设为 8,输出通道数 (D=256)。两个输入帧共享该编码器,得到特征图 (F_1, F_2 \in \mathbb{R}^{H \times W \times D})。

为了后续循环更新,还需要提取一个仅依赖 (I_1) 的上下文特征,通常使用与上述结构类似但稍深一点的网络,输出 (H \times W \times C_{\text{context}}),该特征将在每次迭代中与其它信息拼接送入 GRU。

相关金字塔

4D 相关代价体:对 (F_1) 中的每个位置 (x) 和 (F_2) 中的每个位置 (y),计算归一化内积: [ \mathbf{C}(x, y) = \frac{ \langle F_1(x), F_2(y) \rangle }{ \sqrt{D} } ] 得到尺寸为 (H \times W \times H \times W) 的张量。直接存储和查找在计算上不可行,因此 RAFT 通过池化构建一个多级相关金字塔

金字塔构建:对最后两个维度(即 (F_2) 的空间维度)反复进行平均池化(例如 kernel=2,stride=2),生成多个尺度:(H \times W \times H \times W)、(H \times W \times H/2 \times W/2)、(H \times W \times H/4 \times W/4) … 通常4层。金字塔越顶层分辨率越低,对应更大的感受野,适合大运动;底层精细,适合精确调整。

循环更新操作器

这是 RAFT 的核心创新。它从初始光流场 (\mathbf{f}_0 = \mathbf{0}) 开始,在每一步 (k) 执行:

  1. 相关查找:根据当前流场 (\mathbf{f}_k),从 (F_1) 中的位置 x 映射到 (F_2) 中的 (x + \mathbf{f}_k(x))。在相关金字塔的每一层,我们以映射点为中心,采集局部窗口(如半径=4)内的相关值,得到金字塔相关特征。将所有层的查找结果拼接,形成输入给 GRU 的相关特征向量。
  2. 更新操作器:GRU 的输入由三部分拼接而成:上一步隐藏状态导出的流场特征、金字塔相关特征、以及预先提取的上下文特征。GRU 输出用于估计残差光流 (\Delta \mathbf{f})。
  3. 流场更新:(\mathbf{f}_{k+1} = \mathbf{f}_k + \Delta \mathbf{f})。

这个迭代过程可反复进行(默认12次),最终输出精化的 1/8 分辨率光流。可视作一个学习到的迭代优化算法,每次迭代都基于之前估计的流场从相关金字塔中汲取新证据。

上采样:由于光流在原始图像分辨率下通常具有平滑且边界清晰的特性,RAFT 采用一个凸组合上采样器。它预测每个低分辨率像素映射到高分辨率的掩码权重,对邻域格点进行加权求和,从而恢复全分辨率光流。

RAFT 的优势与特性

  • 极高精度:在 Sintel、KITTI 等基准上长期占据首位,甚至超越许多有监督方法。
  • 强泛化能力:训练时仅用合成数据(如 FlyingChairs + FlyingThings),即可在真实视频上表现良好,无需在目标域微调。
  • 高效率:虽然含有循环结构,但迭代次数少(12次),单次前传推理约 7 fps(2010年代GPU),远快于许多稠密匹配方法。后续加速版还可实时。
  • 结构纯净:没有复杂的解码器或代价卷形状变换,代码易读易复现。
  • 灵活输出:可直接嵌入到视频理解下游任务,模型可端到端训练。

实战:如何使用 RAFT 进行光流估计

下面以官方 PyTorch 实现为例,展示从环境搭建到结果可视化的完整流程。

环境准备

# 创建虚拟环境(可选)
conda create -n raft python=3.8
conda activate raft

# 安装依赖
pip install torch torchvision opencv-python matplotlib imageio
git clone https://github.com/princeton-vl/RAFT.git
cd RAFT

如果只想推理,不需要安装额外训练库。主要依赖是 PyTorch (>=1.6) 和 CUDA 工具包。

加载预训练模型

RAFT 提供多个官方权重,如 raft-things.pth(在 FlyingThings 上训练,通用性好)或 raft-sintel.pth(在 Sintel 上微调)。下载并放置在 models/ 目录下。

import torch
from RAFT.core.raft import RAFT
from RAFT.core.utils.utils import InputPadder

model = RAFT(args)  # args 可默认构造,也可自定义
model = torch.nn.DataParallel(model)
checkpoint = torch.load('models/raft-things.pth')
model.load_state_dict(checkpoint, strict=False)
model.cuda()
model.eval()

注意参数 args.small=False 使用大模型,如需更低延迟可使用 --small 加载小模型。

推理单对图片

准备两张连续帧图片 frame1.pngframe2.png,将它们转换为 [0,255] 的 uint8 数组或归一化 Tensor。

import cv2, numpy as np
from torchvision import transforms

def load_image(path):
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = torch.from_numpy(img).permute(2,0,1).float()[None]  # 1,3,H,W
    return img

image1 = load_image('frame1.png')
image2 = load_image('frame2.png')

# 保证尺寸能被8整除
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

with torch.no_grad():
    flow_low, flow_up = model(image1.cuda(), image2.cuda(), iters=20, test_mode=True)
# flow_up 即为最终全分辨率光流 [1,2,H,W]

iters 控制循环迭代次数,默认24或12均可,越高精度稍增但速度下降。test_mode=True 关闭梯度。

可视化光流结果

可使用 flow_vis 工具将光流转换为色轮图(H代表方向,V代表大小):

from RAFT.core.utils.flow_viz import flow_to_image
import matplotlib.pyplot as plt

# 转换为 numpy
flow_np = flow_up[0].permute(1,2,0).cpu().numpy()
flow_color = flow_to_image(flow_np, convert_to_bgr=False)  # RGB

plt.imshow(flow_color)
plt.title('RAFT Optical Flow')
plt.axis('off')
plt.show()

你也可以保存为图片或视频。flow_viz 默认使用 HSV 色盘:不同颜色表示运动方向(如下图例),颜色饱和度表示运动大小。

Flow color wheel example
(图示:色轮表示光流方向,白/黑为静止)

进阶:在自定义数据上微调 RAFT

如果你需要在特定场景(如医学影像、红外视频)优化性能,可对 RAFT 进行微调。基本步骤:

  1. 准备数据集:格式为包含成对图片和 .flo 光流文件的文件夹结构(如 FlyingChairs 风格)。
  2. 配置训练参数:修改 train.py 中的参数,指定预训练权重路径等。
  3. 启动训练
    python train.py --name my_finetune \
        --stage chairs \
        --restore_ckpt models/raft-things.pth \
        --num_steps 50000 \
        --batch_size 6 \
        --gpus 0
    
    建议使用较小的学习率(如 1e-5)和少步数,防止过拟合。
  4. 验证:可使用 evaluate.py 在验证集上测试 EPE 等指标。

注意:RAFT 对训练数据中的大运动范围适应性很好,小规模自定义数据集只需少量训练即可收敛。

常见问题与调优建议

  1. 输入图像尺寸限制:RAFT 要求尺寸能被 8 整除,若不成则用 InputPadder 进行补边(补零或复制边缘),并记录 pad 值以裁剪回原尺寸。
  2. 内存溢出:全对场相关代价体占显存较大,可通过降低特征分辨率或使用 --small 模型缓解。推论时 batch_size 通常为1。
  3. 遮挡区域光流不准:这是所有方法的通病。RAFT 能通过迭代推断遮挡的一致性,但仍会输出明显噪声。可结合前向后向一致性检查剔除错误估计。
  4. 实时性需求:原始 RAFT 约 10 fps。可以使用加速版(如 RAFT-Stereo 的优化技巧、TensorRT、ONNX 导出)达到实时。也可减少迭代次数至 4 ~ 8 次,精度损失较小。
  5. 与 PWC-Net 比较:RAFT 精度显著更高,对剧烈运动更鲁棒,但计算量略大。若硬件受限且任务简单,仍可考虑 PWC-Net。
  6. 输出光流比例:务必确认预处理中图片的缩放因子。官方实现中图片取值范围保持 [0,255],光流值以像素为单位。若输入被归一化或缩放,输出流需相应反变换。

总结

RAFT 重新定义了光流估计的范式,用全对场相关、多尺度金字塔查找和轻量循环单元实现了高精度与强泛化的统一。无论你是想将其作为视频分析流水线的一环,还是进行学术研究,掌握 RAFT 都能带来显著提升。本教程从原理到代码完整覆盖,现在你就可以动手尝试并应用到自己项目中。

进一步阅读

开始探索光流世界吧!