理论上来说,掌握了如何推理,可以算是个披着狼皮的高级人工智能工程师了——(你虽然不懂原理,但你会用啊,打肿脸充胖子总是可以的),OK,那接下来就来讲一下,Pytorch应该如何推理
pytorch 推理
Pytorch
的推理过程一般为:保存训练后模型的网络参数文件——>读取网络参数模型文件——>将其导入Model模型中——>调用模型中的前向传播forward
函数进行推理
pytorch 读取文件推理的两个关键函数
torch.load()
读取.pth文件或者.pth.tar文件,以字典方式存储
model.load_state_dict()
model为网络模型,将模型文件参数载入模型
1. .pth文件
首先保证自己的路径中存在要读取的模型参数文件.pth文件,然后再读入其文件。
check=torch.load('srcnn_x2.pth')
用print函数看看里面包含了那些部分
print(check.keys())
可以看到,.pth读进来的字典中存储的是每一层值的结构参数。既然已经成功读取了模型参数文件,下一步要做的就是创建模型对象,将参数导入模型
model=SRCNN()
model.load_state_dict(check)
print('模型加载成功')
运行后终端窗口显示模型加载成功,说明参数已经成功导入模型,可以正常进行推理了
2. .pth.tar 文件
同样,首先先保证自己的路径中存在后缀为.pth.tar的对应模型文件,再进行读入
check=torch.load(mbv3_small.pth.tar)
同样也通过print函数打印读入的字典,看看其结构是否与.pth文件不同
print(check.keys())
可以看到,此时check字典中存的不再是网络中层模型的参数值,他储存的是整个训练网络的信息,例如epoch
,best_pre
,optimizer
等等,模型参数实际上存储在state_dict
中,我们可以对state_dict
键值进行进一步打印
print(check['state_dic'].keys())
可以看到,state_dic
中存储的是.pth文件中的内容,所以.pth.tar
文件可以看作是.pth
文件的扩展,最后再进行模型的读入即可
model=args.model
因为有些网络是用cuda训练的,所以创建时需要使用nn.DataParallel()
函数配合.cuda()
将模型转化为cuda形
model=nn.DataParallel(model).cuda()
model.load_state_dict(check['state_dict'])
最后终端显示模型加载成功,即模型参数导入完毕