Pytorch 中retain_graph的用法详解
用法分析
在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?
############################ #(1)UpdateDnetwork:maximizeD(x)-1-D(G(z)) ########################### real_img=Variable(target) iftorch.cuda.is_available(): real_img=real_img.cuda() z=Variable(data) iftorch.cuda.is_available(): z=z.cuda() fake_img=netG(z) netD.zero_grad() real_out=netD(real_img).mean() fake_out=netD(fake_img).mean() d_loss=1-real_out+fake_out d_loss.backward(retain_graph=True)##### optimizerD.step() ############################ #(2)UpdateGnetwork:minimize1-D(G(z))+PerceptionLoss+ImageLoss+TVLoss ########################### netG.zero_grad() g_loss=generator_criterion(fake_out,fake_img,real_img) g_loss.backward() optimizerG.step() fake_img=netG(z) fake_out=netD(fake_img).mean() g_loss=generator_criterion(fake_out,fake_img,real_img) running_results['g_loss']+=g_loss.data[0]*batch_size d_loss=1-real_out+fake_out running_results['d_loss']+=d_loss.data[0]*batch_size running_results['d_score']+=real_out.data[0]*batch_size running_results['g_score']+=fake_out.data[0]*batch_size
在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;
其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,
如下代码:
importtorch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward() output2.backward()
输出如下错误信息:
--------------------------------------------------------------------------- RuntimeErrorTraceback(mostrecentcalllast)in () ---->1output1.backward() 2output2.backward() D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.pyinbackward(self,gradient,retain_graph,create_graph) 91products.Defaultsto``False``. 92""" --->93torch.autograd.backward(self,gradient,retain_graph,create_graph) 94 95defregister_hook(self,hook): D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.pyinbackward(tensors,grad_tensors,retain_graph,create_graph,grad_variables) 88Variable._execution_engine.run_backward( 89tensors,grad_tensors,retain_graph,create_graph, --->90allow_unreachable=True)#allow_unreachableflag 91 92 RuntimeError:Tryingtobackwardthroughthegraphasecondtime,butthebuffershavealreadybeenfreed.Specifyretain_graph=Truewhencallingbackwardthefirsttime.
修改成如下正确:
importtorch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward(retain_graph=True) output2.backward()
#假如你有两个Loss,先执行第一个的backward,再执行第二个backward loss1.backward(retain_graph=True) loss2.backward()#执行完这个后,所有中间变量都会被释放,以便下一次的循环 optimizer.step()#更新参数
Variable类源代码
classVariable(_C._VariableBase):
"""
Attributes:
data:任意类型的封装好的张量。
grad:保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
requires_grad:标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
volatile:标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
is_leaf:标记变量是否是图叶子(如由用户创建的变量)的bool值.
grad_fn:Gradientfunctiongraphtrace.
Parameters:
data(anytensorclass):要包装的张量.
requires_grad(bool):bool型的标记值.**Keywordonly.**
volatile(bool):bool型的标记值.**Keywordonly.**
"""
defbackward(self,gradient=None,retain_graph=None,create_graph=None,retain_variables=None):
"""计算关于当前图叶子变量的梯度,图使用链式法则导致分化
如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
函数在叶子上累积梯度,调用前需要对该叶子进行清零。
Arguments:
grad_variables(Tensor,VariableorNone):
变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
retain_graph(bool,optional):如果为False,用来计算梯度的图将被释放。
在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
默认值为create_graph的值。
create_graph(bool,optional):为True时,会构造一个导数的图,用来计算出更高阶导数结果。
默认为False,除非``gradient``是一个volatile变量。
"""
torch.autograd.backward(self,gradient,retain_graph,create_graph,retain_variables)
defregister_hook(self,hook):
"""Registersabackwardhook.
每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->VariableorNone
不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
Example:
>>>v=Variable(torch.Tensor([0,0,0]),requires_grad=True)
>>>h=v.register_hook(lambdagrad:grad*2)#doublethegradient
>>>v.backward(torch.Tensor([1,1,1]))
>>>v.grad.data
2
2
2
[torch.FloatTensorofsize3]
>>>h.remove()#removesthehook
"""
ifself.volatile:
raiseRuntimeError("cannotregisterahookonavolatilevariable")
ifnotself.requires_grad:
raiseRuntimeError("cannotregisterahookonavariablethat"
"doesn'trequiregradient")
ifself._backward_hooksisNone:
self._backward_hooks=OrderedDict()
ifself.grad_fnisnotNone:
self.grad_fn._register_hook_dict(self)
handle=hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id]=hook
returnhandle
defreinforce(self,reward):
"""Registersarewardobtainedasaresultofastochasticprocess.
区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
Parameters:
reward(Tensor):带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
"""
ifnotisinstance(self.grad_fn,StochasticFunction):
raiseRuntimeError("reinforce()canbeonlycalledonoutputs"
"ofstochasticfunctions")
self.grad_fn._reinforce(reward)
defdetach(self):
"""返回一个从当前图分离出来的心变量。
结果不需要梯度,如果输入是volatile,则输出也是volatile。
..注意::
返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
"""
result=NoGrad()(self)#thisisneeded,becauseitmergesversioncounters
result._grad_fn=None
returnresult
defdetach_(self):
"""从创建它的图中分离出变量并作为该图的一个叶子"""
self._grad_fn=None
self.requires_grad=False
defretain_grad(self):
"""Enables.gradattributefornon-leafVariables."""
ifself.grad_fnisNone:#no-opforleaves
return
ifnotself.requires_grad:
raiseRuntimeError("can'tretain_gradonVariablethathasrequires_grad=False")
ifhasattr(self,'retains_grad'):
return
weak_self=weakref.ref(self)
defretain_grad_hook(grad):
var=weak_self()
ifvarisNone:
return
ifvar._gradisNone:
var._grad=grad.clone()
else:
var._grad=var._grad+grad
self.register_hook(retain_grad_hook)
self.retains_grad=True
以上这篇Pytorch中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。