视觉 Transformer ViT:将图像分解为 Patch 序列
FreeGuideOnline
最新
2026-06-21
z₀ = [x_class; x¹_pE; x²_pE; ... ; x^N_pE] + E_pos z'ₗ = MSA(LayerNorm(z_{l-1})) + z_{l-1} for l = 1 ... L zₗ = FFN(LayerNorm(z'ₗ)) + z'ₗ y = LayerNorm(zₒ[0]) // 取出 [class] token 对应的输出
其中 `E` 是 Patch 投影矩阵,`E_pos` 是位置编码矩阵,`zₒ[0]` 代表第一层输出的 `[class] token` 表征。
## 从全局表征到类别预测
经过 L 层编码器后,`[class] token` 的输出向量 `y` 已经融合了全图信息。将 `y` 送入一个简单的分类头(一层全连接 + softmax),得到各类别的概率分布。分类头在预训练和微调阶段配置不同:
- **预训练**:通常使用大型有监督数据集(如 ImageNet-21k)或 JFT-300M,分类头为对应类别数。
- **微调**:更换为一个自适应的小分类头,并用相应任务的数据集训练。
## 为什么 Patch 序列能工作?
与 CNN 相比,ViT 的全局感受野从第一层就开始生效,而 CNN 需要通过堆叠层逐渐扩大感受野。这使得 ViT 在处理需要长距离相互依赖的任务(如对象形状、场景布局)时具有天然优势。不过,这种架构也意味着对小数据的归纳偏置更弱,通常需要海量数据预训练(或使用强数据增强、正则化)才能超越 CNN。当数据量足够大时,ViT 可以学习到更通用、更强大的视觉表示。
## 动手实现一个最小 ViT
下面使用 PyTorch 给出一个简化的 ViT 实现框架,重点关注 Patch Embedding 和位置编码的构建。
```python
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# 使用卷积层代替线性投影,一步完成切分和嵌入
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
# x shape: (B, C, H, W)
x = self.proj(x) # (B, embed_dim, H/P, W/P)
x = x.flatten(2) # (B, embed_dim, N)
x = x.transpose(1, 2) # (B, N, embed_dim)
return x
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3,
num_classes=1000, embed_dim=768, depth=12, num_heads=12,
mlp_ratio=4., qkv_bias=True, drop_rate=0.):
super().__init__()
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
self.n_patches = self.patch_embed.n_patches
# 可学习的 [class] token 和位置编码
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.n_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
# Transformer 编码器
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),
dropout=drop_rate, activation='gelu', batch_first=True, norm_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
# 分类头
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
# 初始化权重
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
# 其他层由 PyTorch 默认初始化,可自行补充
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # (B, N, embed_dim)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1) # (B, 1+N, embed_dim)
x = x + self.pos_embed
x = self.pos_drop(x)
x = self.transformer(x) # (B, 1+N, embed_dim)
x = self.norm(x[:, 0]) # 只取 [class] token
x = self.head(x)
return x