Swin Transformer:移动窗口的层级视觉 Transformer

FreeGuideOnline 最新 2026-06-21

输入图像(H×W×3) ↓ Patch Partition(图块分割) → H/4 × W/4 × C ↓ Stage 1:Linear Embedding + Swin Transformer Blocks → H/4 × W/4 × C ↓ Stage 2:Patch Merging(下采样) + Swin Transformer Blocks → H/8 × W/8 × 2C ↓ Stage 3:Patch Merging + Swin Transformer Blocks → H/16 × W/16 × 4C ↓ Stage 4:Patch Merging + Swin Transformer Blocks → H/32 × W/32 × 8C


输出的多尺度特征图可直接用于特征金字塔网络(FPN)或 U-Net 解码器,让 Swin 在检测、分割任务中极其得心应手。

## 核心组件一:Patch Merging —— 优雅的下采样
### 动机:如何合并特征又不丢失信息?
在 CNN 中,下采样通常用步长卷积或池化层实现。Swin 则提出了 **Patch Merging**,其本质是一种无参数、可学习的空间到深度转换,目的是在不引入过多额外开销的前提下,将 2×2 的特征块合并成一个,同时倍乘通道数。

### 工作原理
假设输入特征图尺寸为 `H×W×C`(为简化,假设 H、W 均可被 2 整除)。Patch Merging 的操作可分为三步:

1. **空间重排**:将空间维度按间隔采样,分成四个 `H/2 × W/2 × C` 的子图。实现上直接用切片操作,例如取 `[0::2, 0::2]`、`[1::2, 0::2]`、`[0::2, 1::2]`、`[1::2, 1::2]` 四个位置的特征。
2. **拼接**:在通道维度上拼合四个子图,得到 `H/2 × W/2 × 4C` 的特征图。
3. **线性投影**:用一个 1×1 卷积将通道数从 `4C` 映射为 `2C`,完成下采样和通道翻倍。

```python
# 伪代码示例(简化理解)
x0 = x[:, 0::2, 0::2, :]  # 左上
x1 = x[:, 1::2, 0::2, :]  # 左下
x2 = x[:, 0::2, 1::2, :]  # 右上
x3 = x[:, 1::2, 1::2, :]  # 右下
x = torch.cat([x0, x1, x2, x3], dim=-1)  # H/2, W/2, 4C
x = linear_projection(x)  # 输出 H/2, W/2, 2C

这种设计相比直接使用大步长卷积,更好地保留了邻域像素的相对位置关系,且计算量极低。

核心组件二:Swin Transformer Block —— 移动窗口注意力

基本单位:基于窗口的多头自注意力(W-MSA)

全局自注意力的计算复杂度与序列长度平方成正比。对于高分辨率特征图,这几乎是不可接受的。Swin 的解决方案是将特征图均匀划分为不重叠的窗口,只在每个窗口内部计算自注意力。

若特征图大小为 h×w,每个窗口包含 M×M 个 patch(通常 M=7),则窗口数量为 (h/M)×(w/M)。计算复杂度从 O((hw)^2) 骤降至 O(M^4 * 窗口数),即 O(M^2 hw),当 M 固定时与输入分辨率线性相关。

这种局部窗口注意力虽然高效,却有一个致命缺陷:窗口之间没有信息交互,模型视野被局限在单个窗口内,无法构建全局依赖。

突破隔离:移动窗口注意力(SW-MSA)

为了引入跨窗口连接,Swin 提出了移动窗口(Shifted Window)的策略。两个连续的 Transformer Block 形成一个基本单元:

  • 第一个 Block 使用常规的窗口划分(W-MSA);
  • 第二个 Block 使用移动后的窗口划分(SW-MSA)。

具体做法是:在第二个 Block 前,将特征图沿空间方向循环移位(torch.roll)(M//2, M//2) 像素,然后重新划分窗口。此时窗口边界发生偏移,原本不相邻的 patch 被划入同一窗口,从而实现跨窗口信息融合。

但移位后直接划分窗口会带来两个问题:

  • 窗口数量增多(原为 ceil(h/M) 行,移位后可能产生更多不完整窗口)。
  • 直接对不同尺寸的窗口计算注意力会降低效率。

为此,Swin 使用循环移位 + 掩码注意力的巧妙设计:将所有移位的特征图拼接成与原来数量相同的窗口,对跨窗口边界区域引入“遮罩”注意力,阻断非相邻 patch 之间的交互。这里涉及具体的“相对编码 + 掩码矩阵”,将在进阶教程中展开。

相对位置编码

在每个窗口内计算自注意力时,Swin 添加了可学习的相对位置偏置(Relative Position Bias)B ∈ R^{M² × M²}。对于窗口内的 i 与 j patch,注意力权重计算为:

Attention(Q, K, V) = Softmax(QK^T/√d + B)V

B 由每个 patch 在窗口内的坐标差异决定。与 ViT 中的绝对位置编码不同,相对位置偏置能更好地适应窗口的移动,且参数量极小((2M-1)×(2M-1) 个参数),因为坐标差异范围有限。

层级设计的系统优势

将上述组件组装在一起,Swin Transformer 的层级设计带来了以下系统级优势:

1. 线性计算复杂度与分辨率适应性

窗口注意力使得计算复杂度与输入分辨率成线性关系,这让 Swin 可以轻松处理高分辨率图像(如 800×1333 的检测输入),而 ViT 在同样分辨率下将因内存不足而崩溃。

2. 逐层扩大的感受野

  • Stage 1:窗口尺寸 M(默认7)相对输入图像较小,主要捕获纹理、边缘等极局部特征。
  • Stage 2 ~ 4:随着 Patch Merging 不断下采样,每个 patch 覆盖的原图区域变大,同样的窗口尺寸 M 对应的有效感受野呈指数扩张。在 Stage 4,单个窗口已能覆盖全局大部分区域,实现语义级别的理解。

这种从局部到全局的渐进式抽象,与 CNN 的可视化特征图表现一致,使得 Swin 能够同时胜任细节敏感的密集预测任务和全局语义分类任务。

3. 即插即用的多尺度特征

四个 Stage 输出的特征图尺度恰好形成 {1/4, 1/8, 1/16, 1/32} 的金字塔,与经典的 ResNet-50 完全对齐。这意味着开发者可以将现有检测、分割框架(如 Mask R-CNN、UperNet)中的 ResNet 骨干直接替换为 Swin,无需修改检测头或解码器架构,即可获得显著的精度提升。

实战配置:Swin-T 的层级参数

为了让读者更直观地感受层级设计,以下展示最常用的 Swin-T(Tiny)版本的详细配置参数:

阶段 输出分辨率(输入224²) 通道数 层数(Block 数) 窗口大小 M
Stage 1 56×56 96 2 (1 W-MSA + 1 SW-MSA) × 2 组 = 4 Block 7
Stage 2 28×28 192 2 (1+1) × 2 组 = 4 Block 7
Stage 3 14×14 384 6 (1+1) × 3 组 = 12 Block 7
Stage 4 7×7 768 2 (1+1) × 1 组 = 4 Block 7

可以看到,计算量最大的 Block 被分配在分辨率较低的 Stage 3,而高分辨率的浅层阶段层数较少,这种分配兼顾了效率与表达能力。同时,所有阶段保持窗口大小 M=7 不变,让感受野随下采样自然增长,而无需手动调整窗口尺寸。

代码视角:如何在 PyTorch 中构建层级 Swin

下面用精简伪代码展示层级构建的核心逻辑(基于官方实现简化),帮助你理解模块如何串联。

class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, embed_dim=96, 
                 depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, embed_dim) # H/4, W/4
        self.num_layers = len(depths)

        # 构建各Stage
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i),
                input_resolution=(img_size // (2**(i+2)), ) * 2,
                depth=depths[i],
                num_heads=num_heads[i],
                window_size=7,
                downsample=PatchMerging if i > 0 else None  # 第一级无下采样
            )
            self.layers.append(layer)

    def forward(self, x):
        x = self.patch_embed(x)
        outputs = []
        for layer in self.layers:
            x = layer(x)
            outputs.append(x)  # 多尺度特征输出
        return outputs