当前位置:主页 > 软件编程 > Python代码 >

pytorch模型存储的2种实现方法

时间:2021-11-24 13:58:52 | 栏目:Python代码 | 点击:

1、保存整个网络结构信息和模型参数信息:

torch.save(model_object, './model.pth')

直接加载即可使用:

model = torch.load('./model.pth')

2、只保存网络的模型参数-推荐使用

torch.save(model_object.state_dict(), './params.pth')

加载则要先从本地网络模块导入网络,然后再加载参数:

from models import AgeModel
model = AgeModel()
model.load_state_dict(torch.load('./params.pth'))

您可能感兴趣的文章:

相关文章