欢迎来到代码驿站!

Python代码

当前位置:首页 > 软件编程 > Python代码

pytorch模型存储的2种实现方法

时间:2021-11-24 13:58:52|栏目:Python代码|点击:

1、保存整个网络结构信息和模型参数信息:

torch.save(model_object, './model.pth')

直接加载即可使用:

model = torch.load('./model.pth')

2、只保存网络的模型参数-推荐使用

torch.save(model_object.state_dict(), './params.pth')

加载则要先从本地网络模块导入网络,然后再加载参数:

from models import AgeModel
model = AgeModel()
model.load_state_dict(torch.load('./params.pth'))

上一篇:python将多个文本文件合并为一个文本的代码(便于搜索)

栏    目:Python代码

下一篇:tensorflow1.0学习之模型的保存与恢复(Saver)

本文标题:pytorch模型存储的2种实现方法

本文地址:http://www.codeinn.net/misctech/184659.html

推荐教程

广告投放 | 联系我们 | 版权申明

重要申明:本站所有的文章、图片、评论等,均由网友发表或上传并维护或收集自网络,属个人行为,与本站立场无关。

如果侵犯了您的权利,请与我们联系,我们将在24小时内进行处理、任何非本站因素导致的法律后果,本站均不负任何责任。

联系QQ:914707363 | 邮箱:codeinn#126.com(#换成@)

Copyright © 2020 代码驿站 版权所有