视觉 Transformer ViT:将图像分解为 Patch 序列

FreeGuideOnline 最新 2026-06-21

z₀ = [x_class; x¹_pE; x²_pE; ... ; x^N_pE] + E_pos z'ₗ = MSA(LayerNorm(z_{l-1})) + z_{l-1} for l = 1 ... L zₗ = FFN(LayerNorm(z'ₗ)) + z'ₗ y = LayerNorm(zₒ[0]) // 取出 [class] token 对应的输出


其中 `E` 是 Patch 投影矩阵,`E_pos` 是位置编码矩阵,`zₒ[0]` 代表第一层输出的 `[class] token` 表征。

## 从全局表征到类别预测

经过 L 层编码器后,`[class] token` 的输出向量 `y` 已经融合了全图信息。将 `y` 送入一个简单的分类头(一层全连接 + softmax),得到各类别的概率分布。分类头在预训练和微调阶段配置不同:
- **预训练**:通常使用大型有监督数据集(如 ImageNet-21k)或 JFT-300M,分类头为对应类别数。
- **微调**:更换为一个自适应的小分类头,并用相应任务的数据集训练。

## 为什么 Patch 序列能工作?

与 CNN 相比,ViT 的全局感受野从第一层就开始生效,而 CNN 需要通过堆叠层逐渐扩大感受野。这使得 ViT 在处理需要长距离相互依赖的任务(如对象形状、场景布局)时具有天然优势。不过,这种架构也意味着对小数据的归纳偏置更弱,通常需要海量数据预训练(或使用强数据增强、正则化)才能超越 CNN。当数据量足够大时,ViT 可以学习到更通用、更强大的视觉表示。

## 动手实现一个最小 ViT

下面使用 PyTorch 给出一个简化的 ViT 实现框架,重点关注 Patch Embedding 和位置编码的构建。

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # 使用卷积层代替线性投影,一步完成切分和嵌入
        self.proj = nn.Conv2d(in_chans, embed_dim, 
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x shape: (B, C, H, W)
        x = self.proj(x)                # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)                # (B, embed_dim, N)
        x = x.transpose(1, 2)           # (B, N, embed_dim)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, 
                 num_classes=1000, embed_dim=768, depth=12, num_heads=12, 
                 mlp_ratio=4., qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.n_patches = self.patch_embed.n_patches

        # 可学习的 [class] token 和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer 编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=drop_rate, activation='gelu', batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # 分类头
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # 其他层由 PyTorch 默认初始化,可自行补充

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                   # (B, N, embed_dim)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)     # (B, 1+N, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.transformer(x)                   # (B, 1+N, embed_dim)

        x = self.norm(x[:, 0])                    # 只取 [class] token
        x = self.head(x)
        return x