pytorch模型存储的2种实现方法
时间:2021-11-24 13:58:52|栏目:Python代码|点击: 次
1、保存整个网络结构信息和模型参数信息:
torch.save(model_object, './model.pth')
直接加载即可使用:
model = torch.load('./model.pth')
2、只保存网络的模型参数-推荐使用
torch.save(model_object.state_dict(), './params.pth')
加载则要先从本地网络模块导入网络,然后再加载参数:
from models import AgeModel
model = AgeModel()
model.load_state_dict(torch.load('./params.pth'))
上一篇:python将多个文本文件合并为一个文本的代码(便于搜索)
栏 目:Python代码
下一篇:tensorflow1.0学习之模型的保存与恢复(Saver)
本文标题:pytorch模型存储的2种实现方法
本文地址:http://www.codeinn.net/misctech/184659.html






