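PyTorch Hub is a community-driven repository of pretrained deep learning models. It exposes a simple interface for loading these models directly in your own code, so you do not have to train them from scratch. It also supports Colab and integrates with Papers With Code, a site that links papers to their code, for use in broader research.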
Official site: PyTorch Hub
GitHub: pytorch/hub
1 Basic Usage
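Using PyTorch Hub is straightforward: there is no need to download a model by hand. Just call torch.hub.load(), which fetches the repository and the weights on first use and caches them locally.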
[Example] Load a pretrained DeepLabV3 model with a ResNet-101 backbone
import torch
model = torch.hub.load('pytorch/vision:v0.4.2', 'deeplabv3_resnet101', pretrained=True)
print(model.eval())
Using cache found in /Users/hyperplasma/.cache/torch/hub/pytorch_vision_v0.4.2
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/hyperplasma/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████| 170M/170M [00:39<00:00, 4.49MB/s]
Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /Users/hyperplasma/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth
100%|██████████| 233M/233M [00:54<00:00, 4.48MB/s]
DeepLabV3(
...
)
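The first line of the log above shows that the repository string 'pytorch/vision:v0.4.2' ends up in a cache folder named pytorch_vision_v0.4.2. The following is a minimal sketch of that naming pattern, written only to make the log easier to read. The helper hub_cache_dir_name and its default_ref parameter are hypothetical names introduced here, not part of the torch.hub API, and the fallback used when no ref is given is an assumption.

```python
def hub_cache_dir_name(repo_spec: str, default_ref: str = "main") -> str:
    """Map 'owner/repo[:ref]' to a folder name shaped like the one in the log.

    Hypothetical helper for illustration only; not a torch.hub function.
    """
    # Split 'owner/repo:ref' into the repository path and the optional ref.
    repo_path, _, ref = repo_spec.partition(":")
    owner, repo = repo_path.split("/")
    # Assumption: use a placeholder ref when none is given in the string.
    ref = ref or default_ref
    # Slashes in a ref would not be valid inside a single folder name.
    return "_".join([owner, repo, ref.replace("/", "_")])


print(hub_cache_dir_name("pytorch/vision:v0.4.2"))  # pytorch_vision_v0.4.2
```

Because the folder name includes the ref, loading a different tag of the same repository should land in a separate cache folder rather than overwrite this one.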
2 Loading the YOLOv5 Model
https://pytorch.org/hub/ultralytics_yolov5/
Requirements: Python>=3.8, PyTorch>=1.7
Install the YOLOv5 dependencies:
pip install -U ultralytics
This example uses a pretrained YOLOv5s model. YOLOv5 accepts URL, filename, PIL, OpenCV, NumPy, and PyTorch inputs, and returns detections in torch, pandas, or JSON format.
import torch
# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
# Images
imgs = ['https://ultralytics.com/images/zidane.jpg'] # batch of images
# Inference
results = model(imgs)
# Results
results.print()
results.save() # or .show()
...
YOLOv5 🚀 2024-9-2 Python-3.11.5 torch-2.1.2 CPU
Downloading https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s.pt to yolov5s.pt...
100%|██████████| 14.1M/14.1M [00:03<00:00, 4.53MB/s]
Fusing layers...
YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs
Adding AutoShape...
image 1/1: 720x1280 2 persons, 1 tie, 1 cell phone
Speed: 3987.2ms pre-process, 67.8ms inference, 5.8ms NMS per image at shape (1, 3, 384, 640)
Saved 1 image to runs/detect/exp
results.xyxy[0] # img1 predictions (tensor)
tensor([[7.45579e+02, 4.84703e+01, 1.14269e+03, 7.20000e+02, 8.68910e-01, 0.00000e+00],
[1.24744e+02, 1.97335e+02, 8.44397e+02, 7.16651e+02, 6.30324e-01, 0.00000e+00],
[4.41239e+02, 4.39351e+02, 4.98381e+02, 7.08571e+02, 6.16793e-01, 2.70000e+01],
[5.94082e+02, 3.77300e+02, 6.35424e+02, 4.37148e+02, 2.74014e-01, 6.70000e+01]])
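Each row of the tensor above is one detection, and the columns follow the order shown in the pandas view of the same result: xmin, ymin, xmax, ymax, confidence, class. As a rough sketch of how such rows might be post-processed, the snippet below copies the four rows into plain Python lists and keeps only detections above a confidence threshold. The 0.5 threshold and the helper name filter_detections are assumptions chosen for illustration, and plain lists stand in for the tensor so the snippet runs without PyTorch.

```python
# Rows copied from the tensor output above:
# [xmin, ymin, xmax, ymax, confidence, class]
detections = [
    [745.579, 48.4703, 1142.69, 720.000, 0.868910, 0.0],
    [124.744, 197.335, 844.397, 716.651, 0.630324, 0.0],
    [441.239, 439.351, 498.381, 708.571, 0.616793, 27.0],
    [594.082, 377.300, 635.424, 437.148, 0.274014, 67.0],
]


def filter_detections(rows, min_confidence=0.5):
    """Keep rows whose confidence reaches the threshold (hypothetical helper)."""
    kept = []
    for xmin, ymin, xmax, ymax, confidence, class_id in rows:
        if confidence >= min_confidence:
            kept.append({
                "box": (xmin, ymin, xmax, ymax),
                "width": xmax - xmin,
                "height": ymax - ymin,
                "confidence": confidence,
                "class_id": int(class_id),  # stored as a float in the tensor
            })
    return kept


kept = filter_detections(detections)
print([d["class_id"] for d in kept])  # [0, 0, 27]
```

With this threshold the low-confidence cell phone detection (class 67) is dropped, while both persons and the tie remain.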
results.pandas().xyxy[0] # img1 predictions (pandas)
xmin ymin xmax ymax confidence class \
0 745.578674 48.470337 1142.694214 720.000000 0.868910 0
1 124.744324 197.334564 844.396912 716.650513 0.630324 0
2 441.238708 439.350647 498.380737 708.570984 0.616793 27
3 594.081787 377.300323 635.423950 437.147797 0.274014 67
name
0 person
1 person
2 tie
3 cell phone
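The name column of the table above carries the same information as the summary line printed by results.print() ("2 persons, 1 tie, 1 cell phone"). As a small sketch, the snippet below rebuilds that summary from the class names alone. It uses a plain list in place of the DataFrame so it runs without pandas, and the simple "add an s" plural rule is an assumption made here to mirror the log, not a description of how YOLOv5 formats the line.

```python
from collections import Counter

# Class names copied from the 'name' column of the table above.
names = ["person", "person", "tie", "cell phone"]


def summarize(class_names):
    """Build a '2 persons, 1 tie' style summary (illustrative helper)."""
    counts = Counter(class_names)  # keeps first-seen order of the names
    parts = []
    for name, count in counts.items():
        # Assumption: a naive plural that simply appends 's'.
        label = name + "s" if count > 1 else name
        parts.append(f"{count} {label}")
    return ", ".join(parts)


print(summarize(names))  # 2 persons, 1 tie, 1 cell phone
```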