pytorch加载自定义网络权重的实现
在将自定义的网络权重加载到网络中时,报错:
AttributeError:'dict'objecthasnoattribute'seek'.Youcanonlytorch.loadfromafilethatisseekable.Pleasepre-loadthedataintoabufferlikeio.BytesIOandtrytoloadfromitinstead.
我们一步一步分析。
模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')
(1)查看获取模型权重的源码:
pytorch源码:net.state_dict()
defstate_dict(self,destination=None,prefix='',keep_vars=False): r"""Returnsadictionarycontainingawholestateofthemodule. Bothparametersandpersistentbuffers(e.g.runningaverages)are included.Keysarecorrespondingparameterandbuffernames. Returns: dict: adictionarycontainingawholestateofthemodule Example:: >>>module.state_dict().keys() ['bias','weight'] """
将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!
(2)查看保存模型权重的源码:
pytorch源码:torch.save()
defsave(obj,f,pickle_module=pickle,pickle_protocol=DEFAULT_PROTOCOL): """Savesanobjecttoadiskfile. Seealso::ref:`recommend-saving-models` Args: obj:savedobject f:afile-likeobject(hastoimplementwriteandflush)orastring containingafilename pickle_module:moduleusedforpicklingmetadataandobjects pickle_protocol:canbespecifiedtooverridethedefaultprotocol ..warning:: IfyouareusingPython2,torch.savedoesNOTsupportStringIO.StringIO asavalidfile-likeobject.Thisisbecausethewritemethodshouldreturn thenumberofbyteswritten;StringIO.write()doesnotdothis. Pleaseusesomethinglikeio.BytesIOinstead.
函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。
解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()
#b为自定义的字典 torch.save(b,'new.pkl') net.load_state_dict(torch.load(b))
解决方法很简单,主要记录解决思路。
以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。