您当前的位置:首页 > IT编程 > 图像修复
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch |

自学教程:Pytorch快速下载预训练模型并修改保存路径

51自学网 2020-06-04 11:19:45
  图像修复
这篇教程Pytorch快速下载预训练模型并修改保存路径写得很实用,希望能帮到您。
【Pytorch】快速下载预训练模型并修改保存路径


首次用Pytorch加载预训练模型,需要在线下载,但是下载速度比较慢。下载后会保存在本地缓存里。如果能直接加载本地下载好的模型就会快了,主要是个修改路径的问题。

所以要提升速度一般有两种方法:
1.修改torch源码,一次性改变下载url
2.将离线模型权重存到缓存文件夹里
参考:pytorch预训练模型的下载地址以及解决下载速度慢的方法 - you-wh - 博客园
参考:【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径_ProLover98的博客-CSDN博客
参考:pytorch 加载(.pth)格式的模型_人工智能_u014264373的博客-CSDN博客(没有修改存储路径)
但是用云服务器时候这两种方法都有点问题。如果预训练模型的下载路径和存储路径能随用随改就最好了。

所以以vgg16为例,本文采用的方法是:

import torch
from torchvision import models
pthfile = 'file:///mnt/model/vgg16-397923af.pth'  #在下载好的pth文件路径前加file:///得到url
pthsavefile = '/mnt/vgg16-397923af.pth'  #这是模型保存的路径
model = models.vgg.vgg16(pretrained=False, progress=True) #定义一个不需要预训练的模型。如果pretrained=True就会自动下载了
state_dict = torch.utils.model_zoo.load_url(pthfile, model_dir=pthsavefile,
    map_location=None, progress=True, check_hash=False)
# 从pthfile下载到pthsavefile。默认model_dir为none
model.load_state_dict(state_dict) # 读取下载好的模型
# 设置好参数就可以train了
model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)

   

模型可以任意换成别的,比如

models.vgg.vgg16
models.resnet.resnet18
models.resnet.resnext50_32x4d
 
pytorch实现从本地加载 .pth 格式模型
CelebA数据集详细介绍及其属性提取源代码
51自学网,即我要自学网,自学EXCEL、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。
京ICP备13026421号-1