Twins Transformer:空间可分离自注意力的高效 ViT
Twins Transformer 简介
Twins Transformer(Twins‑SVT)是一种专为视觉任务设计的高效 Vision Transformer(ViT)变体。它通过空间可分离自注意力(Spatially Separable Self‑Attention,SSSA) 机制,在保持全局建模能力的同时大幅降低计算成本,尤其适用于高分辨率图像。与标准自注意力对全图所有位置进行两两交互不同,Twins 将注意力分解为局部与全局两个阶段,实现类似卷积的归纳偏置,同时保留 Transformer 的长程依赖捕捉能力。
核心动机:标准自注意力的计算瓶颈
标准自注意力公式如下:
[ \text{Attention}(Q, K, V) = \text{softmax}!\left(\frac{QK^\top}{\sqrt{d}}\right)V ]
对于尺寸为 ( H \times W ) 的特征图,其计算复杂度为 ( O((HW)^2) )。当输入分辨率增大时,计算量和显存占用会急剧膨胀,限制了 ViT 在密集预测任务(如分割、检测)中的应用。Twins 的目标就是设计一种更高效的注意力模式,在精度与效率之间取得平衡。
空间可分离自注意力(SSSA)
SSSA 将一次完整的自注意力拆分成两个连续的步骤:
- 局部空间自注意力(Locally‑grouped Self‑Attention,LSA)
在划分好的局部窗口内计算自注意力,捕捉短距离依赖。 - 全局下采样自注意力(Global Sub‑sampled Self‑Attention,GSA)
对特征图进行稀疏采样,在采样后的少量代表性位置上计算自注意力,实现高效的跨窗口信息交互。
通过这种分解,模型先聚合局部细节,再通过全局摘要进行上下文传播,避免了全图密集交互。
局部空间自注意力(LSA)细节
LSA 将特征图均匀划分为 ( m \times m ) 的局部窗口(典型窗口大小为 7×7 或 14×14)。每个窗口内部独立执行标准自注意力。
计算步骤:
- 输入特征 ( X \in \mathbb{R}^{H \times W \times C} ) 被重塑为 ( (\frac{H}{m} \times \frac{W}{m}, m^2, C) ),即 batch 大小为窗口数。
- 在每个窗口内计算 ( Q, K, V ) 并执行注意力操作。
- 输出恢复至原始空间尺寸,并通过残差连接与原输入相加。
优势:
- 复杂度从 ( O((HW)^2) ) 降至 ( O(HW \cdot m^2) ),且窗口数可随分辨率线性扩展。
- 保留了空间邻近像素间的精细关系,类似于卷积的局部性。
全局下采样自注意力(GSA)细节
GSA 负责建立窗口之间的长程依赖。它通过对特征图进行稀疏采样来减少参与交互的 token 数量。
典型实现方式:
- 跨步采样:将特征图沿空间维度进行步长为 ( s ) 的下采样,得到 ( \frac{H}{s} \times \frac{W}{s} ) 个稀疏 token。
- 在这些稀疏 token 上计算标准全局自注意力。
- 生成的全局特征通过上采样或注意力映射广播回原始空间位置。
另一种等效视角——可分离自注意力: TWins 论文中提出,GSA 可视为先沿宽度方向再沿高度方向(或反之)执行一维自注意力。具体而言:
- 对每一行,将所有列的 token 进行注意力聚合(水平全局交互)。
- 对每一列,将所有行的 token 进行注意力聚合(垂直全局交互)。
这种“宽度‑高度”分离的注意力机制只需要 ( O(HW \times (H+W)) ) 的复杂度,远低于全自注意力的 ( O((HW)^2) )。
效果:
- 以极低计算代价实现了全局感受野。
- 配合 LSA,模型兼具局部细节保持与全局上下文理解的双重能力。
Twins‑SVT 整体架构
Twins Transformer 可作为通用骨干网络,其结构遵循现代分层 Transformer 设计:
- Patch Embedding
将输入图像分割为不重叠的 patch(如 4×4),通过卷积或线性投影转换为 token 序列。 - 多阶段特征提取
包含 4 个阶段,每个阶段由多个 Twins 模块堆叠而成,同时逐步降低空间分辨率、增加通道数。 - 每个 Twins 模块
由 LSA 和 GSA 交替组成,并配备前馈网络(FFN)、层归一化(LayerNorm)和残差连接。 - 任务特定头部
对于分类,在最后阶段后添加全局平均池化和全连接层;对于下游密集任务,可将多阶段特征送入 FPN 等结构。
标准配置示例(Twins‑SVT‑S):
- 阶段 1:输出分辨率 H/4 × W/4,通道数 C = 64。
- 阶段 2:分辨率 H/8 × W/8,C = 128。
- 阶段 3:分辨率 H/16 × W/16,C = 256。
- 阶段 4:分辨率 H/32 × W/32,C = 512。
每个阶段包含不同数量的 LSA‑GSA 组合块。
与其他高效注意力的对比
| 方法 | 注意力范围 | 计算复杂度 | 特点 |
|---|---|---|---|
| 标准 ViT | 全局 | ( O((HW)^2) ) | 高计算量,高分辨率下不可行 |
| Swin Transformer | 移动窗口局部 + 窗口间偏移 | ( O(HW \cdot w^2) ),( w ) 为窗口大小 | 有限视野,需交替窗口偏移 |
| Twins (LSA + GSA) | 局部 + 全局稀疏 | LSA: ( O(HW \cdot m^2) );GSA: ( O(HW \cdot (H+W)) ) | 明确的局部与全局分工,更高效的全局交互 |
Twins 的 GSA 通过行列分离或下采样提供了真正的全局交互,而 Swin 的跨窗口连接仍受窗口偏移限制,因此 Twins 在需要强长程语义理解的任务中更具优势。
实验效果与优势总结
- 分类精度:在 ImageNet‑1K 上,Twins‑SVT 以相近或更低的 FLOPs 超越 Swin Transformer 等同类模型。
- 密集预测:在 COCO 目标检测和 ADE20K 语义分割任务中,Twins 骨干网络比 Swin 获得更高的 mAP 和 mIoU,同时推理速度更快。
- 可扩展性:可通过调整窗口大小、下采样比例等参数灵活适配不同算力预算。
- 无需特殊超参:与 CPVT(条件位置编码)结合使用时,可省略传统位置编码,进一步提升性能。
快速上手指南
安装依赖
推荐在 PyTorch 环境下使用 timm 库或官方实现。以 timm 为例:
pip install timm
加载预训练模型
import timm
# Twins‑SVT‑S 模型,ImageNet‑22k 预训练
model = timm.create_model('twins_svt_small', pretrained=True)
model.eval()
推理示例
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
image = Image.open('example.jpg').convert('RGB')
tensor = transform(image).unsqueeze(0) # 添加 batch 维度
with torch.no_grad():
output = model(tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
作为骨干网络用于检测/分割
许多开源框架(如 MMDetection、Detectron2)已集成 Twins 骨干,只需在配置文件中设置为 backbone=dict(type='TwinsSVT', ...) 即可调用。
总结
Twins Transformer 通过空间可分离自注意力,将局部窗口注意力与高效的全局稀疏注意力优雅结合,在维持 Transformer 全局建模优势的同时,将计算量控制在可行范围内。这一设计使其在众多视觉任务中展现出强大的竞争力,尤其适合需要同时处理高分辨率图像和长程依赖的下游任务。对希望深入了解高效 ViT 设计的读者,Twins 提供了一个清晰、模块化且易于复现的出色范例。