PyTorch Hub 实战:一行代码加载预训练模型
PyTorch Hub 是什么?
PyTorch Hub 是 PyTorch 官方提供的预训练模型库与模型共享平台。它的核心目标是让研究者与开发者只需一行代码,就能加载经过验证的预训练模型,极大降低复现论文结果和进行迁移学习的门槛。无论你关注的是计算机视觉、自然语言处理还是生成模型,都能在 Hub 中找到社区贡献的高质量模型。
核心优势
- 零配置加载:无需手动下载权重文件、编写模型定义,一行代码即可完成。
- 模型发现与探索:通过
torch.hub.list()可以快速列出某个仓库中所有可用的预训练模型。 - 可复现性:通过指定提交哈希或标签,确保每次加载的模型权重完全一致。
- 开放生态:任何人都可以发布经过测试的模型,形成共建生态。
环境准备与安装
PyTorch Hub 内置于 PyTorch 1.0 及以上版本,无需额外安装。但推荐使用较新版本以获得更好的模型覆盖和稳定性。
# 建议安装最新稳定版 PyTorch
pip install torch torchvision --upgrade
# 验证安装
python -c "import torch; print(torch.__version__)"
如果你的环境中还没有安装 PyTorch,请根据官方指南选择适合的版本。
一行代码加载预训练模型
PyTorch Hub 最经典的用法就是通过 torch.hub.load() 加载模型。基本语法如下:
model = torch.hub.load(repo_or_dir, model, *args, **kwargs)
repo_or_dir:GitHub 仓库标识(格式用户名/仓库名)或本地目录路径。model:仓库中hubconf.py定义的模型入口函数名。*args, **kwargs:传递给入口函数的参数,例如pretrained=True。
实例:加载 ResNet-18 预训练模型
以计算机视觉中常用的 ResNet-18 为例,一行代码即可获得 ImageNet 预训练权重:
import torch
# 从 pytorch/vision 仓库加载 resnet18 模型,并使用预训练权重
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# 设置为评估模式
model.eval()
运行这段代码时,PyTorch 会自动:
- 从 GitHub 拉取
pytorch/vision仓库的v0.10.0标签对应的快照。 - 执行仓库根目录下的
hubconf.py,找到resnet18函数。 - 下载预训练权重(如果本地没有)并构造模型。
版本标签的作用:使用
pytorch/vision:v0.10.0确保即使官方仓库后续更新,你加载的模型行为也不会改变。省略标签则始终拉取主分支最新代码,可能带来兼容性风险。
其他常用模型示例
# 加载 MobileNet V2(适合移动端部署)
mobilenet = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
# 加载 DeepLabV3 语义分割模型(ResNet-101 骨干)
deeplab = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True)
# 加载生成对抗网络 DCGAN(来自 facebookresearch 仓库)
dcgan = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN')
探索可用的模型
在加载模型之前,你可以先查看某个仓库提供了哪些预训练模型:
# 列出 pytorch/vision 仓库 v0.10.0 版本中的所有模型入口
available_models = torch.hub.list('pytorch/vision:v0.10.0')
print(available_models)
# 输出示例:['alexnet', 'deeplabv3_resnet101', 'mobilenet_v2', 'resnet18', ...]
torch.hub.list() 会返回一个字符串列表,每个元素都是一个可直接供 torch.hub.load() 使用的 model 名称。
常用官方仓库速查
| 领域 | 仓库标识 | 说明 |
|---|---|---|
| 视觉 | pytorch/vision |
图像分类、目标检测、分割等经典模型 |
| 自然语言 | huggingface/pytorch-transformers(旧版) |
大量 Transformer 模型,现建议用 transformers 库 |
| 生成模型 | facebookresearch/pytorch_GAN_zoo |
DCGAN、CGAN 等生成对抗网络 |
| 语音 | snakers4/silero-models |
高质量语音识别与合成模型 |
| 推荐系统 | NVIDIA/DeepLearningExamples(部分模型) |
官方示例中包含的推荐模型 |
模型的使用与推理
加载模型后,通常需要输入符合模型要求的张量。以 ResNet-18 为例,它的输入是形状为 (N, 3, 224, 224) 的张量,像素值需归一化到特定的均值和标准差。
import torch
from PIL import Image
from torchvision import transforms
# 加载模型
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()
# 图像预处理管线
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载本地图片
img = Image.open('cat.jpg')
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0) # 增加 batch 维度
# 如果有 GPU,将模型和输入移至 GPU
if torch.cuda.is_available():
model = model.to('cuda')
input_batch = input_batch.to('cuda')
# 推理
with torch.no_grad():
output = model(input_batch)
# output 是 1000 类的 logits
probabilities = torch.nn.functional.softmax(output[0], dim=0)
对于语义分割模型(如 DeepLabV3),输出是类别掩码,处理方式略有不同:
deeplab = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True)
deeplab.eval()
# 预处理与上面类似,但输出是 OrderedDict,包含 'out' 键
output = deeplab(input_batch)['out']
# output 形状为 (1, 21, H, W),取每个像素 argmax 得到类别图
将模型用于迁移学习
预训练模型最大的价值在于特征复用。你可以轻松修改模型的最后几层,以适应自己的任务。
import torch.nn as nn
# 加载 ResNet-18 并固定特征提取层参数
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
for param in model.parameters():
param.requires_grad = False
# 替换最后的全连接层,假设我们的任务是 10 分类
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
# 只训练新层
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
如果希望微调整个网络,可以在训练后期解冻部分层,这完全取决于你的任务和计算资源。
发布自己的模型到 PyTorch Hub
想让自己的模型被社区使用?只需在你的 GitHub 仓库根目录添加一个 hubconf.py 文件。
1. 编写 hubconf.py
hubconf.py 中定义任意多个可被调用的函数,每个函数都返回一个 PyTorch 模型(nn.Module)。函数名就是模型入口名。
# hubconf.py
import torch
from my_model import MyAwesomeModel
def my_model(pretrained=False, **kwargs):
"""返回 MyAwesomeModel 实例。"""
model = MyAwesomeModel(**kwargs)
if pretrained:
# 从可访问的 URL 加载预训练权重
checkpoint = torch.hub.load_state_dict_from_url('https://example.com/weights.pth', progress=True)
model.load_state_dict(checkpoint)
return model
2. 发布指南
- 将你的仓库设为公开,并确保
hubconf.py、模型定义文件、依赖列表(requirements.txt)都在根目录或能正确导入。 - 最好使用版本标签(如
v1.0)标记稳定发布,这样用户通过用户名/仓库名:v1.0加载时能得到确定性的结果。 - 建议在
hubconf.py中导入依赖时做好异常处理,给出清晰的错误提示。
3. 使用方式
其他人只需一行代码即可加载你的模型:
model = torch.hub.load('your_github_username/your_repo:v1.0', 'my_model', pretrained=True)
高级技巧与常见问题
强制重新加载与缓存管理
PyTorch Hub 会将下载的仓库快照和权重缓存到本地(通常为 ~/.cache/torch/hub/)。如果你修改了 hubconf.py 或想强制刷新,可以设置 force_reload=True。
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', force_reload=True)
使用 torch.hub.set_dir() 可以更改缓存目录:
torch.hub.set_dir('/my/custom/cache')
注意:
force_reload会重新下载仓库代码,但权重文件如果已存在则默认不重复下载,除非你同时设置环境变量TORCH_HOME相关的策略。
离线使用
在无网络环境中,你可以:
- 在有网络时先加载一次模型,让缓存生成。
- 将整个缓存目录(
~/.cache/torch/hub/)复制到目标机器相同路径。 - 确保仓库代码和权重文件都齐全,即可离线加载。
指定模型保存路径
通过 torch.hub.load_state_dict_from_url(model_url, model_dir='./custom_weights') 可以在发布模型时自定义权重下载路径。
安全性考虑
从未知来源的仓库加载模型时,请注意 hubconf.py 会作为 Python 代码被执行。请只加载你信任的仓库。官方仓库和知名研究机构的仓库通常经过了社区审查。
性能提示
- 频繁创建模型时,推荐重用已加载的模型实例,避免重复 I/O。
- 推理前调用
model.eval()以禁用 dropout 和 batch normalization 的训练行为。 - 使用
torch.no_grad()上下文管理器关闭梯度计算,节省显存和计算量。
总结
PyTorch Hub 将“一行代码使用预训练模型”变成了现实,极大地加速了研究原型设计与工业应用落地。通过本文,你学会了:
- 如何使用
torch.hub.load()快速加载视觉、NLP 等领域的高质量模型。 - 如何探索可用模型并正确进行预处理与推理。
- 如何在迁移学习场景中微调预训练模型。
- 如何将自己的模型发布到 Hub 供全球开发者使用。
现在,打开你的 Python 环境,用一行代码开启预训练模型的强大功能吧!