CSWin Transformer:十字形窗口自注意力机制
什么是 CSWin Transformer?
CSWin Transformer 是一种高效的视觉 Transformer 架构,由微软亚洲研究院在 2021 年提出。它的核心创新在于 十字形窗口自注意力机制(Cross-Shaped Window Self-Attention),通过将传统方形窗口拆分为水平与垂直条纹窗口并行计算,在保持全局建模能力的同时大幅降低计算复杂度,实现了速度与精度的优异平衡。
为什么需要十字形窗口注意力?
- 标准自注意力的瓶颈:全局自注意力计算复杂度与图像分辨率平方成正比(O(N²)),导致高分辨率视觉任务难以落地。
- Swin Transformer 的局限:Swin 使用固定大小的方形窗口进行局部注意力,虽然计算高效,但窗口之间信息交互有限,依赖移位窗口传递信息,感受野增长较慢。
- 十字形窗口的思路:人眼视觉系统中存在水平与垂直方向的敏感神经元。CSWin 受此启发,让注意力在水平和垂直两条正交的细长条纹上分别计算,既能获取全局上下文,又保持线性复杂度。
十字形窗口自注意力机制详解
创新点:双重并行自注意力
CSWin 不直接计算全局注意力,也不局限于方形窗口,而是将多头注意力拆分为水平条纹组和垂直条纹组,两组并行计算后拼接结果。
1. 条纹窗口划分
给定输入特征图,尺寸为 H × W × C。将其在通道维度等分成两个分支:
- 水平分支:将特征图按高度方向分成若干个横贯整个宽度的水平条纹,每个条纹宽度为
sw(sw 为超参数,条纹宽度)。 - 垂直分支:将特征图按宽度方向分成若干个贯穿整个高度的垂直条纹,每个条纹高度为
sw。
2. 条纹内自注意力
在两个分支内,各自在每个条纹内部执行标准自注意力:
- 水平条纹内,每个 token 只关注同一行内的所有 token。
- 垂直条纹内,每个 token 只关注同一列内的所有 token。
由于条纹宽度(或高度)通常设为 sw,远小于整图尺寸,计算复杂度从 O(H²W²) 降为 O(HW × sw)。实际操作中,sw 默认设为 7 或 12。
3. 十字形感受野融合
虽然单个分支的注意力只在一个方向上全局化,但并行之后,一个 token 在水平分支获得了在整行上的大范围依赖,在垂直分支获得了整列上的大范围依赖。两者叠加,等效于每个 token 的注意力范围形成了一个十字形窗口,在全图快速建立起长程连接。
4. 位置编码
CSWin 使用局部增强的位置编码(Locally-Enhanced Positional Encoding, LePE),在每个注意力头的结果上添加一个由深度可分离卷积(深度卷积核为 3×3)生成的局部位置偏差,而不是流行的相对位置偏差表。LePE 更灵活、对分辨率变化更鲁棒。
整体架构设计
CSWin Transformer 沿袭了层次化金字塔结构,由四个阶段堆叠而成,每个阶段包含若干 CSWin Transformer Block。
CSWin Transformer Block
输入
│
├─ LayerNorm
├─ 十字形窗口自注意力(CSW-MSA)
│ ├─ 水平条纹注意力头组
│ └─ 垂直条纹注意力头组
├─ LePE 位置编码
└─ 残差连接
│
├─ LayerNorm
├─ MLP (带 GELU 激活)
└─ 残差连接
│
输出
阶段配置(以 CSWin-T 为例)
| 阶段 | 输出分辨率 | 通道数 | 块数 | 条纹宽度 sw |
|---|---|---|---|---|
| 1 | H/4 × W/4 | 64 | 1 | 1(相当于全局注意力) |
| 2 | H/8 × W/8 | 128 | 2 | 2 |
| 3 | H/16 × W/16 | 256 | 21 | 7 |
| 4 | H/32 × W/32 | 512 | 1 | 7 |
前两个阶段使用较小的 sw 以保留更多细节,深层阶段增大 sw 获取更大感受野。第一阶段 sw=1 时,条纹即单行/单列,等价于全局注意力,但计算量仍可控。
降采样通过卷积 Patch Embedding(4×4 卷积步长 4 → 2×2 卷积步长 2)完成,直接继承自早期 ViT 的分块嵌入与卷积下采样的结合。
十字形自注意力的实现细节(伪代码解析)
class CSWinBlock(nn.Module):
def forward(self, x):
# x: (B, H, W, C)
x_norm = LayerNorm(x)
# 通道分割为水平和垂直两个分支
x_h, x_v = x_norm.chunk(2, dim=-1) # 各 C//2 通道
# 水平分支:reshape 为 (B*num_stripes_h, sw, W, C//2)
# 垂直分支:reshape 为 (B*num_stripes_v, H, sw, C//2)
# 在条纹维度执行多头自注意力
attn_h = horizontal_stripe_attention(x_h, sw, num_heads//2)
attn_v = vertical_stripe_attention(x_v, sw, num_heads//2)
# 拼接两支结果
attn_out = torch.cat([attn_h, attn_v], dim=-1)
# 加上局部位置编码
attn_out = attn_out + self.lepe(x_norm)
# 投影与残差
return x + self.proj(attn_out)
horizontal_stripe_attention:将H切分为H//sw个条纹,每个条纹内执行WindowAttention(width=W),计算后重排回原形状。vertical_stripe_attention:将W切分为W//sw个条纹,每个条纹内执行WindowAttention(height=H)。lepe:通过一个深度可分离卷积DepthwiseConv2d(kernel_size=3, padding=1)处理输入,生成位置偏置矩阵直接加到注意力输出上。
这种方式无需复杂的移位操作,避免了 Swin 中 mask 机制的额外开销。
主要实验结论与效果
在 ImageNet-1K 分类、COCO 目标检测和 ADE20K 语义分割上,CSWin 在当时均达到领先水平。
- 分类 Top-1 准确率:CSWin-T(tiny) 83.3%,CSWin-S 84.2%,CSWin-B 85.0%,CSWin-L 85.7%。在同等 FLOPs 下显著超越 Swin 和 DeiT。
- 目标检测:使用 Mask R-CNN 框架,CSWin-S 达到 50.3 APbox,比 Swin-S 高约 1.1 AP。
- 分割:UPerNet 框架下,CSWin-T 的 mIoU 达 49.3,优于 Swin-T。
特别地,CSWin 在模型更大时优势更明显,显示出十字形注意力设计对大模型容量利用得更充分。
优点与局限速览
✅ 优点
- 高分辨率友好:条纹自注意力复杂度仅为
O(HW × sw),计算量随分辨率线性增长。 - 全局感受野建立快:浅层即可获得整行/整列依赖,无需堆叠大量层数。
- 实现简单:无需移位窗口、无需复杂的 mask 计算,易于工程部署。
- 可扩展性强:调节 sw 即可灵活控制计算量与精度。
⚠️ 局限
- 条纹边界可能产生微弱不连续性:虽然 LePE 有所缓解,但严格的全图连续性不如全局注意力。
- 对极端长宽比图像需注意条纹划分方式,可能需动态调整 sw。
- 工程实现中 stripe re-group 操作可能带来内存访问开销,需优化算子融合。
适用场景与迁移学习建议
CSWin 在以下任务中表现突出:
- 高分辨率图像分类(如医学图像、卫星图像)
- 密集预测任务(检测、分割、姿态估计)
- 需要全局依赖建模的长序列特征分析
迁移到下游任务时,建议:
- 使用 CSWin 官方预训练权重作为初始化。
- 根据输入分辨率适度调整
sw,保持条纹内 token 数在 50~200 之间较好。 - LePE 模块在微调时保留,它对不同分辨率具有良好的泛化能力。
- 如果显存紧张,可只在前几个阶段使用十字形窗口,底层仍用较小方形窗口,这也是官方支持的变体。
延伸阅读与资源
- 论文:CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows (arXiv 2021)
- 官方代码库:microsoft/CSWin-Transformer
- 对比学习:《Swin Transformer》以理解设计差异
- 进阶技巧:将 CSWin 的空间结构部署到时序或三维数据中,可将条纹扩展为“时空管状”注意力。
通过将注意力局限在相互正交的条纹空间中,CSWin Transformer 成功地在计算效率和全局建模能力之间找到了优雅的平衡点,是视觉 Transformer 演进过程中极具启发性的工作。