1st's Studio.

Pytorch推理相关

字数统计: 626阅读时长: 2 min
2022/04/07

理论上来说,掌握了如何推理,可以算是个披着狼皮的高级人工智能工程师了——(你虽然不懂原理,但你会用啊,打肿脸充胖子总是可以的),OK,那接下来就来讲一下,Pytorch应该如何推理

pytorch 推理

Pytorch 的推理过程一般为:保存训练后模型的网络参数文件——>读取网络参数模型文件——>将其导入Model模型中——>调用模型中的前向传播forward函数进行推理

pytorch 读取文件推理的两个关键函数

  • torch.load()

读取.pth文件或者.pth.tar文件,以字典方式存储

  • model.load_state_dict()

model为网络模型,将模型文件参数载入模型

1. .pth文件

首先保证自己的路径中存在要读取的模型参数文件.pth文件,然后再读入其文件。

check=torch.load('srcnn_x2.pth')

用print函数看看里面包含了那些部分

print(check.keys())

可以看到,.pth读进来的字典中存储的是每一层值的结构参数。既然已经成功读取了模型参数文件,下一步要做的就是创建模型对象,将参数导入模型

model=SRCNN()

model.load_state_dict(check)

print('模型加载成功')

运行后终端窗口显示模型加载成功,说明参数已经成功导入模型,可以正常进行推理了

2. .pth.tar 文件

同样,首先先保证自己的路径中存在后缀为.pth.tar的对应模型文件,再进行读入

check=torch.load(mbv3_small.pth.tar)

同样也通过print函数打印读入的字典,看看其结构是否与.pth文件不同

print(check.keys())

可以看到,此时check字典中存的不再是网络中层模型的参数值,他储存的是整个训练网络的信息,例如epochbest_pre,optimizer等等,模型参数实际上存储在state_dict中,我们可以对state_dict键值进行进一步打印

print(check['state_dic'].keys())

可以看到,state_dic中存储的是.pth文件中的内容,所以.pth.tar文件可以看作是.pth文件的扩展,最后再进行模型的读入即可

model=args.model

因为有些网络是用cuda训练的,所以创建时需要使用nn.DataParallel()函数配合.cuda()将模型转化为cuda形

model=nn.DataParallel(model).cuda() model.load_state_dict(check['state_dict'])

最后终端显示模型加载成功,即模型参数导入完毕

CATALOG
  1. 1. pytorch 推理
    1. 1.1. pytorch 读取文件推理的两个关键函数
      1. 1.1.1. 1. .pth文件
      2. 1.1.2. 2. .pth.tar 文件