pytorch中torch.max和Tensor.view函数用法详解
torch.max()
1.
torch.max()简单来说是返回一个tensor中的最大值。
例如:
>>>si=torch.randn(4,5) >>>print(si) tensor([[1.1659,-1.5195,0.0455,1.7610,-0.2064], [-0.3443,2.0483,0.6303,0.9475,0.4364], [-1.5268,-1.0833,1.6847,0.0145,-0.2088], [-0.8681,0.1516,-0.7764,0.8244,-1.2194]]) >>>print(torch.max(si)) tensor(2.0483)
2.
这个函数的参数中还有一个dim参数,使用方法为re=torch.max(Tensor,dim),返回的re为一个二维向量,其中re[0]为最大值的Tensor,re[1]为最大值对应的index的Tensor。
例如:
>>>print(torch.max(si,0)[0]) tensor([1.1659,2.0483,1.6847,1.7610,0.4364])
注意,Tensor的维度从0开始算起。在torch.max()中指定了dim之后,比如对于一个3x4x5的Tensor,指定dim为0后,得到的结果是维度为0的“每一行”对应位置求最大的那个值,此时输出的Tensor的维度是4x5.
对于简单的二维Tensor,如上面例子的这个4x5的Tensor。指定dim为0,则给出的结果是4行做比较之后的最大值;如果指定dim为1,则给出的结果是5列做比较之后的最大值,且此处做比较时是按照位置分别做比较,得到一个新的Tensor。
Tensor.view()
简单说就是一个把tensor进行reshape的操作。
>>>a=torch.randn(3,4,5,7) >>>b=a.view(1,-1) >>>print(b.size()) torch.Size([1,420])
其中参数-1表示剩下的值的个数一起构成一个维度。如上例中,第一个参数1将第一个维度的大小设定成1,后一个-1就是说第二个维度的大小=元素总数目/第一个维度的大小,此例中为3*4*5*7/1=420.
>>>d=a.view(a.size(0),a.size(1),-1) >>>print(d.size()) torch.Size([3,4,35]) >>>e=a.view(4,-1,5) >>>print(e.size()) torch.Size([4,21,5])
以上这篇pytorch中torch.max和Tensor.view函数用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。