PyTorch笔记之scatter()函数的使用
scatter()和scatter_()的作用是一样的,只不过scatter()不会直接修改原来的Tensor,而scatter_()会
PyTorch中,一般函数加下划线代表直接在原来的Tensor上修改
scatter(dim,index,src)的参数有3个
- dim:沿着哪个维度进行索引
- index:用来scatter的元素索引
- src:用来scatter的源元素,可以是一个标量或一个张量
这个scatter可以理解成放置元素或者修改元素
简单说就是通过一个张量src来修改另一个张量,哪个元素需要修改、用src中的哪个元素来修改由dim和index决定
官方文档给出了3维张量的具体操作说明,如下所示
self[index[i][j][k]][j][k]=src[i][j][k]#ifdim==0 self[i][index[i][j][k]][k]=src[i][j][k]#ifdim==1 self[i][j][index[i][j][k]]=src[i][j][k]#ifdim==2
exmaple:
x=torch.rand(2,5) #tensor([[0.1940,0.3340,0.8184,0.4269,0.5945], #[0.2078,0.5978,0.0074,0.0943,0.0266]]) torch.zeros(3,5).scatter_(0,torch.tensor([[0,1,2,0,0],[2,0,0,1,2]]),x) #tensor([[0.1940,0.5978,0.0074,0.4269,0.5945], #[0.0000,0.3340,0.0000,0.0943,0.0000], #[0.2078,0.0000,0.8184,0.0000,0.0266]])
具体地说,我们的index是torch.tensor([[0,1,2,0,0],[2,0,0,1,2]]),一个二维张量,下面用图简单说明
我们是2维张量,一开始进行$self[index[0][0]][0]$,其中$index[0][0]$的值是0,所以执行$self[0][0]=x[0][0]=0.1940$
$self[index[i][j]][j]=src[i][j]$
再比如$self[index[1][0]][0]$,其中$index[1][0]$的值是2,所以执行$self[2][0]=x[1][0]=0.2078$
src除了可以是张量外,也可以是一个标量
example:
torch.zeros(3,5).scatter_(0,torch.tensor([[0,1,2,0,0],[2,0,0,1,2]]),7) #tensor([[7.,7.,7.,7.,7.], #[0.,7.,0.,7.,0.], #[7.,0.,7.,0.,7.]]
scatter()一般可以用来对标签进行one-hot编码,这就是一个典型的用标量来修改张量的一个例子
example:
class_num=10 batch_size=4 label=torch.LongTensor(batch_size,1).random_()%class_num #tensor([[6], #[0], #[3], #[2]]) torch.zeros(batch_size,class_num).scatter_(1,label,1) #tensor([[0.,0.,0.,0.,0.,0.,1.,0.,0.,0.], #[1.,0.,0.,0.,0.,0.,0.,0.,0.,0.], #[0.,0.,0.,1.,0.,0.,0.,0.,0.,0.], #[0.,0.,1.,0.,0.,0.,0.,0.,0.,0.]])
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。