HRNet:保持高分辨率表征的姿态估计主干网络
HRNet 高分辨率网络入门指南
HRNet(High-Resolution Network)是一种在计算机视觉任务中广泛使用的主干网络,特别在人体姿态估计领域取得了里程碑式的成果。与传统网络不断降低分辨率以提取语义特征不同,HRNet 在整个过程中始终保持高分辨率表征,从而获得空间上更精确的特征,非常适合对位置敏感的任务。本教程将带你从原理到实践,全面理解 HRNet 的设计思想与使用方法。
一、为什么需要 HRNet?——从下采样困境到高分辨率保持
大多数经典分类网络(如 ResNet、VGG)采用“编码器-解码器”范式:通过连续的池化或跨步卷积将特征图缩小,获得强语义信息,再通过上采样恢复分辨率。这种模式存在明显缺陷:
- 空间精度丢失:多次下采样导致细粒度位置信息被稀释,不利于像素级任务(关键点定位、语义分割等)。
- 多尺度融合受限:虽然可以通过跳连接(如 U-Net)或空洞卷积减少损失,但高分辨率特征的恢复过程仍然依赖低分辨率特征,原生高分辨率表征不强。
HRNet 的出发点是:能不能从始至终保持一个高分辨率的主干特征,同时并行引入低分辨率分支来增强语义能力? 答案是肯定的,这就是 HRNet 的核心设计。
二、HRNet 整体架构:并行多分辨率子网与重复多尺度融合
HRNet 由多个分辨率不同的子网络并联组成,并不断在子网之间交换信息。以经典的 HRNet-W32 为例,结构如下图所示(文字描述架构):
Stage 1: 高分辨率分支 (1/4 尺度)
Stage 2: 高分辨率分支 (1/4) + 低分辨分支 (1/8)
Stage 3: 高分辨率分支 (1/4) + 低分辨分支 (1/8) + 更低分支 (1/16)
Stage 4: 高分辨率分支 (1/4) + 低分辨分支 (1/8) + 更低分支 (1/16) + 极低分支 (1/32)
2.1 网络启动与主体
- Stem 阶段:两个 stride=2 的 3×3 卷积,将输入图像缩小至原图的 1/4 尺寸(通道数 C 通常为 32 或 48)。
- 阶段 1:由若干个残差单元(Bottleneck)组成,分辨率保持 1/4,通道数从 C 变为 4C(对于 W32,通道数为 32→128)。这一阶段相当于传统网络的开始部分。
- 阶段 2、3、4:逐步添加更低分辨率的子网,并引入多尺度融合模块,让信息在并行分支之间流动。
2.2 并行子网的构建规则
当进入新阶段时,会在当前最低分辨率分支的基础上再通过 stride=2 的卷积生成一条新分支,使其分辨率减半,通道数加倍。因此:
- 阶段 2:两条分支,分辨率 1/4 (通道 4C)、1/8 (通道 8C)
- 阶段 3:三条分支,1/4 (4C)、1/8 (8C)、1/16 (16C)
- 阶段 4:四条分支,1/4 (4C)、1/8 (8C)、1/16 (16C)、1/32 (32C)
每个分支内部由一系列基础残差块(Basic Block 或 Bottleneck)组成,各分支独立进行卷积运算,保持本分支分辨率不变。
三、多尺度融合模块:HRNet 的核心创新点
仅仅拥有多个并行的分辨率分支是不够的,必须让不同分辨率的特征图相互沟通,实现多尺度信息的融合。HRNet 在每段并行卷积之后、进入下一段之前,都会执行一个跨分辨率交换单元 (Exchange Block)。
3.1 融合方式
假设我们要将各个分支的特征“聚合”给某一个目标分支。根据源分支与目标分支的分辨率关系,使用不同操作:
- 同尺度直连:如果源分支与目标分支分辨率相同,不进行变换,直接相加(identity)。
- 高→低融合:使用 stride=2 的 3×3 卷积进行下采样,确保尺寸匹配后相加。
- 低→高融合:先使用 1×1 卷积调整通道数,再通过最近邻上采样将空间尺寸放大至目标尺寸,然后相加。
最终,目标分支的输出是来自所有分支变换后的特征之和。
3.2 融合模块的具体执行过程
以阶段 3 内部的融合为例,包含 3 个输入分支(1/4、1/8、1/16),要对每一个输出分支重复上述聚合操作,共产生 3 个新的特征图,分别作为下一段并行残差块的输入。这保证了每一个分辨率分支都继承了全部分辨率的信息。
四、HRNet 变体与网络配置
HRNet 按照通道数规模分为多个版本,命名规则为 HRNet-Wx,其中 x 表示 Stage 4 最高分辨率分支的通道数。常用版本:
| 版本 | 高分辨率分支通道数 | 参数量(分类头不计) | 适用场景 |
|---|---|---|---|
| HRNet-W18 | 18 | ~21M | 轻量级任务 |
| HRNet-W32 | 32 | ~41M | 姿态估计经典配置 |
| HRNet-W48 | 48 | ~63M | 追求极致精度 |
更高分辨率的输出通常直接采用最后一个阶段的高分辨率分支特征。对于姿态估计任务,会在该特征图上使用 1×1 卷积回归关键点热图;对于分割任务,则可将其作为原始分辨率的分割头输入。
五、HRNet vs 传统网络的直观对比
| 对比维度 | 传统编码器-解码器(如 U-Net) | HRNet |
|---|---|---|
| 高分辨率特征 | 通过上采样恢复,前期丢失 | 全过程保持,原生高分辨率 |
| 多尺度信息 | 跳连接从编码器拷贝 | 并行分支间反复融合 |
| 语义强度 | 靠深层低分辨率特征 | 通过低分辨率分支注入 |
| 空间精度 | 上采样导致网格效应或模糊 | 原生高分辨率,边缘更清晰 |
| 适用任务 | 分割、恢复 | 位置敏感任务(姿态、分割) |
六、PyTorch 代码实现示例
以下展示一个简化版的 HRNet 核心搭建思路(以两条分支为例,便于理解),完整代码可参考官方实现。
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
# 基础残差块,保持尺寸不变
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if in_channels != out_channels:
self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
else:
self.skip = nn.Identity()
def forward(self, x):
residual = self.skip(x)
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += residual
return self.relu(out)
class HRNetStage(nn.Module):
# 一个阶段包含多分辨率并行残差块和多尺度融合
def __init__(self, num_branches, block_per_branch, in_channels_list, out_channels_list):
super().__init__()
self.num_branches = num_branches
self.branches = nn.ModuleList()
for i in range(num_branches):
branch = nn.Sequential(
*[BasicBlock(in_channels_list[i] if j==0 else out_channels_list[i],
out_channels_list[i]) for j in range(block_per_branch)]
)
self.branches.append(branch)
# 多尺度融合模块
self.fuse_modules = nn.ModuleList()
for output_branch in range(num_branches):
fuse_ops = nn.ModuleList()
for input_branch in range(num_branches):
if input_branch == output_branch:
fuse_ops.append(nn.Identity())
elif input_branch < output_branch: # 高 -> 低,下采样
fuse_ops.append(nn.Sequential(
nn.Conv2d(out_channels_list[input_branch], out_channels_list[output_branch],
3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(out_channels_list[output_branch])
))
else: # 低 -> 高,上采样
fuse_ops.append(nn.Sequential(
nn.Conv2d(out_channels_list[input_branch], out_channels_list[output_branch],
1, bias=False),
nn.BatchNorm2d(out_channels_list[output_branch]),
nn.Upsample(scale_factor=2**(input_branch-output_branch), mode='nearest')
))
self.fuse_modules.append(fuse_ops)
def forward(self, x_list):
# x_list 是上一阶段各分支输出列表,长度 = num_branches
# 1. 各分支独立运算
x_list = [branch(x) for branch, x in zip(self.branches, x_list)]
# 2. 多尺度融合
fused = []
for out_branch in range(self.num_branches):
out = self.fuse_modules[out_branch][0](x_list[0])
for in_branch in range(1, self.num_branches):
transformed = self.fuse_modules[out_branch][in_branch](x_list[in_branch])
out = out + transformed
fused.append(F.relu(out))
return fused
# 简化模型:两个分支示例
model = HRNetStage(num_branches=2, block_per_branch=4,
in_channels_list=[32, 64], out_channels_list=[32, 64])
实际 HRNet 还包含用于下一次增加分支的 transition 层等,完整流程通常会封装成更标准的类。建议在实际使用时直接调用 torchvision.models.segmentation 或 mmpose 等库中的预训练模型。
七、训练与调优要点
7.1 数据预处理
- 人体姿态估计常用输入尺寸 256×192 或 384×288。输入图像应进行中心裁剪或缩放并保持宽高比。
- 关键点坐标需转换成热图(高斯核,标准差通常取 1~2 像素)。
7.2 损失函数与优化器
- 一般使用均方误差(MSE)热图损失,也有用 L1 损失。
- 优化器常用 Adam,初始学习率 1e-3,配合余弦退火或阶梯衰减。
- 训练时往往需要从 ImageNet 预训练权重加载,但 HRNet 设计不同,主干不直接对应标准 ResNet 权重的命名,需要特殊转换脚本或使用官方预训练模型。
7.3 精度提升技巧
- 数据增强:随机旋转(±40°)、缩放(0.7~1.3)、水平翻转、半身增强等。
- 后处理:将预测热图在高分辨率分支上进行 argmax,再乘以输出 stride(1/4 原图)映射回原图坐标;还可以对热图进行泰勒展开细化至亚像素精度。
- 多模型集成:不同 W 版本或不同输入尺寸的平均。
八、HRNet 的应用扩展
虽然 HRNet 因姿态估计而出名,但它的高分辨率表征特性使其能直接应用于:
- 语义分割:接一个轻量分割头,如 OCRNet,可以在 Cityscapes 上获得顶尖结果。
- 目标检测:作为主干附属于检测器,对定位精度有明显提升。
- 人脸关键点、手部姿态估计:同样效果好。
- 图像超分辨率与恢复:利用其保留高频细节的能力。
九、总结:何时选择 HRNet?
当你需要网络输出具有较高空间精准度,并且任务涉及关键点、轮廓、边缘等细节时,HRNet 是首选主干之一。它虽然计算量比普通 ResNet 稍大,但通过并行多分支设计,有效平衡了语义强度和空间分辨率,避免了编解码结构的固有弊端。现在许多主流姿态估计框架(MMPose、ViTPose 前身)都基于 HRNet 构建,其思想也启发了后续更高性能的网络。
希望本教程能帮助你理解 HRNet 的原理,并成功将其应用到自己的项目中。动手实践是最好的学习方式,快去尝试训练一个属于自己的姿态估计器吧!