PyTorch Hub 实战:一行代码加载预训练模型

FreeGuideOnline 最新 2026-06-27

PyTorch Hub 是什么?

PyTorch Hub 是 PyTorch 官方提供的预训练模型库与模型共享平台。它的核心目标是让研究者与开发者只需一行代码,就能加载经过验证的预训练模型,极大降低复现论文结果和进行迁移学习的门槛。无论你关注的是计算机视觉、自然语言处理还是生成模型,都能在 Hub 中找到社区贡献的高质量模型。

核心优势

  • 零配置加载:无需手动下载权重文件、编写模型定义,一行代码即可完成。
  • 模型发现与探索:通过 torch.hub.list() 可以快速列出某个仓库中所有可用的预训练模型。
  • 可复现性:通过指定提交哈希或标签,确保每次加载的模型权重完全一致。
  • 开放生态:任何人都可以发布经过测试的模型,形成共建生态。

环境准备与安装

PyTorch Hub 内置于 PyTorch 1.0 及以上版本,无需额外安装。但推荐使用较新版本以获得更好的模型覆盖和稳定性。

# 建议安装最新稳定版 PyTorch
pip install torch torchvision --upgrade

# 验证安装
python -c "import torch; print(torch.__version__)"

如果你的环境中还没有安装 PyTorch,请根据官方指南选择适合的版本。

一行代码加载预训练模型

PyTorch Hub 最经典的用法就是通过 torch.hub.load() 加载模型。基本语法如下:

model = torch.hub.load(repo_or_dir, model, *args, **kwargs)
  • repo_or_dir:GitHub 仓库标识(格式 用户名/仓库名)或本地目录路径。
  • model:仓库中 hubconf.py 定义的模型入口函数名。
  • *args, **kwargs:传递给入口函数的参数,例如 pretrained=True

实例:加载 ResNet-18 预训练模型

以计算机视觉中常用的 ResNet-18 为例,一行代码即可获得 ImageNet 预训练权重:

import torch

# 从 pytorch/vision 仓库加载 resnet18 模型,并使用预训练权重
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

# 设置为评估模式
model.eval()

运行这段代码时,PyTorch 会自动:

  1. 从 GitHub 拉取 pytorch/vision 仓库的 v0.10.0 标签对应的快照。
  2. 执行仓库根目录下的 hubconf.py,找到 resnet18 函数。
  3. 下载预训练权重(如果本地没有)并构造模型。

版本标签的作用:使用 pytorch/vision:v0.10.0 确保即使官方仓库后续更新,你加载的模型行为也不会改变。省略标签则始终拉取主分支最新代码,可能带来兼容性风险。

其他常用模型示例

# 加载 MobileNet V2(适合移动端部署)
mobilenet = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)

# 加载 DeepLabV3 语义分割模型(ResNet-101 骨干)
deeplab = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True)

# 加载生成对抗网络 DCGAN(来自 facebookresearch 仓库)
dcgan = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN')

探索可用的模型

在加载模型之前,你可以先查看某个仓库提供了哪些预训练模型:

# 列出 pytorch/vision 仓库 v0.10.0 版本中的所有模型入口
available_models = torch.hub.list('pytorch/vision:v0.10.0')
print(available_models)
# 输出示例:['alexnet', 'deeplabv3_resnet101', 'mobilenet_v2', 'resnet18', ...]

torch.hub.list() 会返回一个字符串列表,每个元素都是一个可直接供 torch.hub.load() 使用的 model 名称。

常用官方仓库速查

领域 仓库标识 说明
视觉 pytorch/vision 图像分类、目标检测、分割等经典模型
自然语言 huggingface/pytorch-transformers(旧版) 大量 Transformer 模型,现建议用 transformers
生成模型 facebookresearch/pytorch_GAN_zoo DCGAN、CGAN 等生成对抗网络
语音 snakers4/silero-models 高质量语音识别与合成模型
推荐系统 NVIDIA/DeepLearningExamples(部分模型) 官方示例中包含的推荐模型

模型的使用与推理

加载模型后,通常需要输入符合模型要求的张量。以 ResNet-18 为例,它的输入是形状为 (N, 3, 224, 224) 的张量,像素值需归一化到特定的均值和标准差。

import torch
from PIL import Image
from torchvision import transforms

# 加载模型
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()

# 图像预处理管线
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载本地图片
img = Image.open('cat.jpg')
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0)  # 增加 batch 维度

# 如果有 GPU,将模型和输入移至 GPU
if torch.cuda.is_available():
    model = model.to('cuda')
    input_batch = input_batch.to('cuda')

# 推理
with torch.no_grad():
    output = model(input_batch)
# output 是 1000 类的 logits
probabilities = torch.nn.functional.softmax(output[0], dim=0)

对于语义分割模型(如 DeepLabV3),输出是类别掩码,处理方式略有不同:

deeplab = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet101', pretrained=True)
deeplab.eval()

# 预处理与上面类似,但输出是 OrderedDict,包含 'out' 键
output = deeplab(input_batch)['out']
# output 形状为 (1, 21, H, W),取每个像素 argmax 得到类别图

将模型用于迁移学习

预训练模型最大的价值在于特征复用。你可以轻松修改模型的最后几层,以适应自己的任务。

import torch.nn as nn

# 加载 ResNet-18 并固定特征提取层参数
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
for param in model.parameters():
    param.requires_grad = False

# 替换最后的全连接层,假设我们的任务是 10 分类
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# 只训练新层
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

如果希望微调整个网络,可以在训练后期解冻部分层,这完全取决于你的任务和计算资源。

发布自己的模型到 PyTorch Hub

想让自己的模型被社区使用?只需在你的 GitHub 仓库根目录添加一个 hubconf.py 文件。

1. 编写 hubconf.py

hubconf.py 中定义任意多个可被调用的函数,每个函数都返回一个 PyTorch 模型(nn.Module)。函数名就是模型入口名。

# hubconf.py
import torch
from my_model import MyAwesomeModel

def my_model(pretrained=False, **kwargs):
    """返回 MyAwesomeModel 实例。"""
    model = MyAwesomeModel(**kwargs)
    if pretrained:
        # 从可访问的 URL 加载预训练权重
        checkpoint = torch.hub.load_state_dict_from_url('https://example.com/weights.pth', progress=True)
        model.load_state_dict(checkpoint)
    return model

2. 发布指南

  • 将你的仓库设为公开,并确保 hubconf.py、模型定义文件、依赖列表(requirements.txt)都在根目录或能正确导入。
  • 最好使用版本标签(如 v1.0)标记稳定发布,这样用户通过 用户名/仓库名:v1.0 加载时能得到确定性的结果。
  • 建议在 hubconf.py 中导入依赖时做好异常处理,给出清晰的错误提示。

3. 使用方式

其他人只需一行代码即可加载你的模型:

model = torch.hub.load('your_github_username/your_repo:v1.0', 'my_model', pretrained=True)

高级技巧与常见问题

强制重新加载与缓存管理

PyTorch Hub 会将下载的仓库快照和权重缓存到本地(通常为 ~/.cache/torch/hub/)。如果你修改了 hubconf.py 或想强制刷新,可以设置 force_reload=True

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', force_reload=True)

使用 torch.hub.set_dir() 可以更改缓存目录:

torch.hub.set_dir('/my/custom/cache')

注意force_reload 会重新下载仓库代码,但权重文件如果已存在则默认不重复下载,除非你同时设置环境变量 TORCH_HOME 相关的策略。

离线使用

在无网络环境中,你可以:

  1. 在有网络时先加载一次模型,让缓存生成。
  2. 将整个缓存目录(~/.cache/torch/hub/)复制到目标机器相同路径。
  3. 确保仓库代码和权重文件都齐全,即可离线加载。

指定模型保存路径

通过 torch.hub.load_state_dict_from_url(model_url, model_dir='./custom_weights') 可以在发布模型时自定义权重下载路径。

安全性考虑

从未知来源的仓库加载模型时,请注意 hubconf.py 会作为 Python 代码被执行。请只加载你信任的仓库。官方仓库和知名研究机构的仓库通常经过了社区审查。

性能提示

  • 频繁创建模型时,推荐重用已加载的模型实例,避免重复 I/O。
  • 推理前调用 model.eval() 以禁用 dropout 和 batch normalization 的训练行为。
  • 使用 torch.no_grad() 上下文管理器关闭梯度计算,节省显存和计算量。

总结

PyTorch Hub 将“一行代码使用预训练模型”变成了现实,极大地加速了研究原型设计与工业应用落地。通过本文,你学会了:

  • 如何使用 torch.hub.load() 快速加载视觉、NLP 等领域的高质量模型。
  • 如何探索可用模型并正确进行预处理与推理。
  • 如何在迁移学习场景中微调预训练模型。
  • 如何将自己的模型发布到 Hub 供全球开发者使用。

现在,打开你的 Python 环境,用一行代码开启预训练模型的强大功能吧!