Getting a model's input/output shapes in PyTorch: an example
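At the time of writing, PyTorch has no official way to print per-layer shape information directly, the way TensorFlow or Caffe can. For details, see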
https://github.com/pytorch/pytorch/pull/3043
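The code below is a workaround. Modules such as CNNs and RNNs are implemented differently, so supporting other module types may require changes to the code.
For example, in an RNN the bias attribute is a bool rather than a tensor, and the weights are not stored in a weight attribute either (they live in attributes such as weight_ih_l0), so RNN layers show a parameter count of 0. Since we only care about shapes here, that is good enough.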
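To make the RNN caveat concrete, here is a small sketch that is not part of the original workaround. It compares a weight-only count, which mirrors what the workaround does, with a count over every parameter registered directly on the module via `parameters(recurse=False)`. The helper names `weight_only_count` and `full_count` are hypothetical. The LSTM sizes are assumed to match the first LSTM in the CRNN example, based on the shapes in the output further down.

```python
import torch.nn as nn


def weight_only_count(module):
    # Mirrors the workaround: count only an attribute named `weight`, if present
    weight = getattr(module, 'weight', None)
    return weight.numel() if weight is not None else 0


def full_count(module):
    # Count every parameter registered directly on this module:
    # weight_ih_l0, bias_hh_l0, ... for an LSTM; weight and bias for a Conv2d
    return sum(p.numel() for p in module.parameters(recurse=False))


lstm = nn.LSTM(512, 256, bidirectional=True)
conv = nn.Conv2d(1, 64, kernel_size=3, padding=1)

print(type(lstm.bias).__name__)  # on RNN modules, bias is a plain bool flag
print(weight_only_count(lstm), full_count(lstm))
print(weight_only_count(conv), full_count(conv))
```

The weight-only count is 0 for the LSTM, while the full count is 1,576,960. For the convolution, the two counts are 576 and 640, because the 64 bias values are skipped.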
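The shapes only become available after you construct an input and run a forward pass (that is, call model(x)), because they are recorded by forward hooks while the model runs.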
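A forward pass is required because the workaround relies on `register_forward_hook`. A hook only fires when its module actually runs, and only then are the input and output tensors available. Here is a minimal sketch of that mechanism on a small `nn.Sequential`. The toy model and the `shapes` list are illustrative and do not come from the original.

```python
import torch
import torch.nn as nn

# A toy model used only to illustrate the hook mechanism
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

shapes = []  # (layer name, input shape, output shape), filled in by the hooks


def record_shape(module, inputs, output):
    # `inputs` is the tuple of positional arguments; `output` is what forward returned
    shapes.append((module.__class__.__name__, list(inputs[0].shape), list(output.shape)))


handles = [m.register_forward_hook(record_shape) for m in model]

# Nothing has been recorded yet: hooks only run during a forward pass
assert shapes == []
with torch.no_grad():
    model(torch.rand(1, 1, 32, 128))

for h in handles:
    h.remove()  # detach the hooks so later calls are unaffected

for name, in_shape, out_shape in shapes:
    print(name, in_shape, out_shape)
```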
# coding: utf-8
from collections import OrderedDict
import json

import torch
import torch.nn as nn
# CRNN model definition (models/crnn.py of a PyTorch CRNN project, not included here)
import models.crnn as crnn


def get_output_size(summary_dict, output):
    # Modules such as LSTM return nested tuples, e.g. (output, (h_n, c_n)),
    # so recurse and record one shape per tensor under integer keys
    if isinstance(output, tuple):
        for i, out in enumerate(output):
            summary_dict[i] = get_output_size(OrderedDict(), out)
    else:
        summary_dict['output_shape'] = list(output.size())
    return summary_dict


def summary(input_size, model):
    def register_hook(module):
        def hook(module, inputs, output):
            class_name = module.__class__.__name__
            module_idx = len(summary)
            m_key = '%s-%i' % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]['input_shape'] = list(inputs[0].size())
            summary[m_key] = get_output_size(summary[m_key], output)

            params = 0
            if hasattr(module, 'weight'):
                # numel() returns a plain int, which json.dumps can serialize
                params += module.weight.numel()
                summary[m_key]['trainable'] = module.weight.requires_grad
            # Bias is deliberately not counted: for RNNs, module.bias is a bool,
            # not a tensor, so nb_params covers the weight attribute only
            summary[m_key]['nb_params'] = params

        # Hook only real layers: skip containers and the top-level model
        if not isinstance(module, nn.Sequential) and \
           not isinstance(module, nn.ModuleList) and \
           module is not model:
            hooks.append(module.register_forward_hook(hook))

    # Check whether the network takes multiple inputs
    if isinstance(input_size[0], (list, tuple)):
        x = [torch.rand(1, *in_size) for in_size in input_size]
    else:
        x = torch.rand(1, *input_size)

    # State shared with the hooks through the closure
    summary = OrderedDict()
    hooks = []

    # Register a hook on every submodule
    model.apply(register_hook)

    # Make a forward pass (unpack the list for multi-input models)
    if isinstance(x, list):
        model(*x)
    else:
        model(x)

    # Remove the hooks
    for h in hooks:
        h.remove()

    return summary


model = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1, 32, 128], model)
print(json.dumps(x))
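get_output_size recurses because some modules return tuples instead of a single tensor. `nn.LSTM`, for example, returns `(output, (h_n, c_n))`, which is why the LSTM entries in the output below carry nested "0"/"1" keys. Here is a standalone sketch of the same idea. The helper `nested_shapes` is hypothetical and returns nested lists and tuples rather than dicts. The LSTM sizes are assumed to match the first LSTM in the CRNN example.

```python
import torch
import torch.nn as nn


def nested_shapes(output):
    # Hypothetical helper: keep the tuple nesting, replace each tensor by its shape
    if isinstance(output, tuple):
        return tuple(nested_shapes(o) for o in output)
    return list(output.shape)


# Assumed sizes: seq_len=33, batch=1, 512 input features, 256 hidden units per direction
lstm = nn.LSTM(512, 256, bidirectional=True)
with torch.no_grad():
    result = lstm(torch.rand(33, 1, 512))

print(nested_shapes(result))
```

The first element is the per-step output (33, 1, 2 × 256). The inner pair holds h_n and c_n, each shaped (num_layers × num_directions, batch, hidden).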
Using the PyTorch implementation of CRNN as an example, the output shapes are as follows:
{
  "Conv2d-1": {"input_shape": [1, 1, 32, 128], "output_shape": [1, 64, 32, 128], "trainable": true, "nb_params": 576},
  "ReLU-2": {"input_shape": [1, 64, 32, 128], "output_shape": [1, 64, 32, 128], "nb_params": 0},
  "MaxPool2d-3": {"input_shape": [1, 64, 32, 128], "output_shape": [1, 64, 16, 64], "nb_params": 0},
  "Conv2d-4": {"input_shape": [1, 64, 16, 64], "output_shape": [1, 128, 16, 64], "trainable": true, "nb_params": 73728},
  "ReLU-5": {"input_shape": [1, 128, 16, 64], "output_shape": [1, 128, 16, 64], "nb_params": 0},
  "MaxPool2d-6": {"input_shape": [1, 128, 16, 64], "output_shape": [1, 128, 8, 32], "nb_params": 0},
  "Conv2d-7": {"input_shape": [1, 128, 8, 32], "output_shape": [1, 256, 8, 32], "trainable": true, "nb_params": 294912},
  "BatchNorm2d-8": {"input_shape": [1, 256, 8, 32], "output_shape": [1, 256, 8, 32], "trainable": true, "nb_params": 256},
  "ReLU-9": {"input_shape": [1, 256, 8, 32], "output_shape": [1, 256, 8, 32], "nb_params": 0},
  "Conv2d-10": {"input_shape": [1, 256, 8, 32], "output_shape": [1, 256, 8, 32], "trainable": true, "nb_params": 589824},
  "ReLU-11": {"input_shape": [1, 256, 8, 32], "output_shape": [1, 256, 8, 32], "nb_params": 0},
  "MaxPool2d-12": {"input_shape": [1, 256, 8, 32], "output_shape": [1, 256, 4, 33], "nb_params": 0},
  "Conv2d-13": {"input_shape": [1, 256, 4, 33], "output_shape": [1, 512, 4, 33], "trainable": true, "nb_params": 1179648},
  "BatchNorm2d-14": {"input_shape": [1, 512, 4, 33], "output_shape": [1, 512, 4, 33], "trainable": true, "nb_params": 512},
  "ReLU-15": {"input_shape": [1, 512, 4, 33], "output_shape": [1, 512, 4, 33], "nb_params": 0},
  "Conv2d-16": {"input_shape": [1, 512, 4, 33], "output_shape": [1, 512, 4, 33], "trainable": true, "nb_params": 2359296},
  "ReLU-17": {"input_shape": [1, 512, 4, 33], "output_shape": [1, 512, 4, 33], "nb_params": 0},
  "MaxPool2d-18": {"input_shape": [1, 512, 4, 33], "output_shape": [1, 512, 2, 34], "nb_params": 0},
  "Conv2d-19": {"input_shape": [1, 512, 2, 34], "output_shape": [1, 512, 1, 33], "trainable": true, "nb_params": 1048576},
  "BatchNorm2d-20": {"input_shape": [1, 512, 1, 33], "output_shape": [1, 512, 1, 33], "trainable": true, "nb_params": 512},
  "ReLU-21": {"input_shape": [1, 512, 1, 33], "output_shape": [1, 512, 1, 33], "nb_params": 0},
  "LSTM-22": {
    "input_shape": [33, 1, 512],
    "0": {"output_shape": [33, 1, 512]},
    "1": {"0": {"output_shape": [2, 1, 256]}, "1": {"output_shape": [2, 1, 256]}},
    "nb_params": 0
  },
  "Linear-23": {"input_shape": [33, 512], "output_shape": [33, 256], "trainable": true, "nb_params": 131072},
  "BidirectionalLSTM-24": {"input_shape": [33, 1, 512], "output_shape": [33, 1, 256], "nb_params": 0},
  "LSTM-25": {
    "input_shape": [33, 1, 256],
    "0": {"output_shape": [33, 1, 512]},
    "1": {"0": {"output_shape": [2, 1, 256]}, "1": {"output_shape": [2, 1, 256]}},
    "nb_params": 0
  },
  "Linear-26": {"input_shape": [33, 512], "output_shape": [33, 3755], "trainable": true, "nb_params": 1922560},
  "BidirectionalLSTM-27": {"input_shape": [33, 1, 256], "output_shape": [33, 1, 3755], "nb_params": 0}
}
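The nb_params values and the unusual widths in the later layers (32 → 33 → 34 → 33) can be checked by hand. The sketch below does that arithmetic in plain Python. The helper names are hypothetical. The pooling settings (kernel (2, 2), stride (2, 1), padding (0, 1)) and the unpadded 2×2 kernel of Conv2d-19 are assumptions inferred from the reported shapes. The original does not state them.

```python
def conv_weight_params(c_in, c_out, kernel):
    # A Conv2d weight has shape (c_out, c_in, k_h, k_w); bias is not counted here
    k_h, k_w = kernel
    return c_out * c_in * k_h * k_w


def linear_weight_params(in_features, out_features):
    # A Linear weight has shape (out_features, in_features)
    return out_features * in_features


def out_len(size, kernel, stride, padding):
    # Output length along one axis of a conv/pool layer (dilation 1, floor rounding)
    return (size + 2 * padding - kernel) // stride + 1


counts = {
    'Conv2d-1': conv_weight_params(1, 64, (3, 3)),
    'Conv2d-19': conv_weight_params(512, 512, (2, 2)),
    'Linear-23': linear_weight_params(512, 256),
    'Linear-26': linear_weight_params(512, 3755),
}
print(counts)

# MaxPool2d-12: (H, W) = (8, 32) -> (4, 33); height halves while width grows by one
pool12 = (out_len(8, 2, 2, 0), out_len(32, 2, 1, 1))
# MaxPool2d-18: the same pooling maps (4, 33) -> (2, 34)
pool18 = (out_len(4, 2, 2, 0), out_len(33, 2, 1, 1))
# Conv2d-19: 2x2 kernel, stride 1, no padding maps (2, 34) -> (1, 33)
conv19 = (out_len(2, 2, 1, 0), out_len(34, 2, 1, 0))
print(pool12, pool18, conv19)
```

The BatchNorm2d entries (256 and 512) follow the same pattern. Only the per-channel weight is counted, so nb_params equals the number of channels.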