Mixup:基于凸组合的简单而有效的数据增强
FreeGuideOnline
最新
2026-06-21
python import torch import numpy as np
def mixup_data(x, y, alpha=0.4, num_classes=None):
    """Blend every sample in a batch with a randomly chosen partner (mixup).

    Inputs are combined as ``lam * x + (1 - lam) * x[index]`` where ``index``
    is a random permutation of the batch. Labels are *not* blended here: both
    label sets are returned so the caller can combine the losses as
    ``lam * loss(pred, y_a) + (1 - lam) * loss(pred, y_b)``.

    Args:
        x: Input tensor whose first dimension is the batch.
        y: Label tensor. A 1-D tensor is treated as class indices; a tensor
            of any other rank (e.g. one-hot ``(batch_size, num_classes)``)
            is passed through unchanged.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution
            that ``lam`` is drawn from. With ``alpha <= 0`` no mixing takes
            place and ``lam`` is 1.0.
        num_classes: Width of the one-hot encoding applied to 1-D labels.
            When ``None`` (the default) 1-D labels are returned as class
            indices; because the mixup loss is a weighted sum of two
            per-target losses, index targets need no one-hot conversion.

    Returns:
        A tuple ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the labels
        of the original samples, the labels of the permuted partners, and
        the mixing coefficient as a Python float.
    """
    batch_size = x.size(0)
    # Convert the NumPy scalar to a plain float so that arithmetic with
    # tensors on any device never round-trips through NumPy.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    index = torch.randperm(batch_size, device=x.device)
    # Index only the batch dimension so inputs of any rank (including 1-D)
    # are supported.
    mixed_x = lam * x + (1 - lam) * x[index]

    if y.ndim == 1 and num_classes is not None:
        # One-hot encode integer labels when the caller asked for it.
        y_onehot = torch.zeros(batch_size, num_classes, device=y.device)
        # scatter_ requires an int64 index tensor.
        y_onehot.scatter_(1, y.view(-1, 1).long(), 1)
        y_a = y_onehot
        y_b = y_onehot[index]
    else:
        # Already one-hot / soft labels, or class indices left as they are.
        y_a = y
        y_b = y[index]

    return mixed_x, y_a, y_b, lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss for a batch of predictions.

    The loss is the convex combination
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.

    Args:
        criterion: Callable invoked as ``criterion(pred, target)``.
        pred: Model outputs for the mixed inputs.
        y_a: Targets of the original samples.
        y_b: Targets of the permuted partner samples.
        lam: Mixing coefficient applied to the ``y_a`` loss.
    """
    loss_original = criterion(pred, y_a)
    loss_partner = criterion(pred, y_b)
    return lam * loss_original + (1 - lam) * loss_partner
**训练循环中的使用**:
```python
# One pass over the training data with mixup applied to every batch.
for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)

    # Build the mixed inputs together with both label sets and the
    # mixing coefficient.
    inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, alpha=0.4)

    optimizer.zero_grad()
    outputs = model(inputs)
    # The loss is the lam-weighted sum of the losses against both label sets.
    loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
    loss.backward()
    optimizer.step()