pytorch动态网络以及权重共享实例
pytorch动态网络+权值共享
pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:
#-*-coding:utf-8-*- importrandom importtorch classDynamicNet(torch.nn.Module): def__init__(self,D_in,H,D_out): """ 这里构造了几个向前传播过程中用到的线性函数 """ super(DynamicNet,self).__init__() self.input_linear=torch.nn.Linear(D_in,H) self.middle_linear=torch.nn.Linear(H,H) self.output_linear=torch.nn.Linear(H,D_out) defforward(self,x): """ Fortheforwardpassofthemodel,werandomlychooseeither0,1,2,or3 andreusethemiddle_linearModulethatmanytimestocomputehiddenlayer representations. Sinceeachforwardpassbuildsadynamiccomputationgraph,wecanusenormal Pythoncontrol-flowoperatorslikeloopsorconditionalstatementswhen definingtheforwardpassofthemodel. HerewealsoseethatitisperfectlysafetoreusethesameModulemany timeswhendefiningacomputationalgraph.ThisisabigimprovementfromLua Torch,whereeachModulecouldbeusedonlyonce. 这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。 """ h_relu=self.input_linear(x).clamp(min=0) for_inrange(random.randint(0,3)): h_relu=self.middle_linear(h_relu).clamp(min=0) y_pred=self.output_linear(h_relu) returny_pred #Nisbatchsize;D_inisinputdimension; #Hishiddendimension;D_outisoutputdimension. N,D_in,H,D_out=64,1000,100,10 #CreaterandomTensorstoholdinputsandoutputs x=torch.randn(N,D_in) y=torch.randn(N,D_out) #Constructourmodelbyinstantiatingtheclassdefinedabove model=DynamicNet(D_in,H,D_out) #ConstructourlossfunctionandanOptimizer.Trainingthisstrangemodelwith #vanillastochasticgradientdescentistough,soweusemomentum criterion=torch.nn.MSELoss(reduction='sum') optimizer=torch.optim.SGD(model.parameters(),lr=1e-4,momentum=0.9) fortinrange(500): #Forwardpass:Computepredictedybypassingxtothemodel y_pred=model(x) #Computeandprintloss loss=criterion(y_pred,y) print(t,loss.item()) #Zerogradients,performabackwardpass,andupdatetheweights. optimizer.zero_grad() loss.backward() optimizer.step()
这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图
References:PytorchDocumentations.
以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。