PyTorch 解决Dataset和Dataloader遇到的问题
今天在使用PyTorch中Dataset遇到了一个问题。先看代码
classpsDataset(Dataset): def__init__(self,x,y,transforms=None): super(Dataset,self).__init__() self.x=x self.y=y iftransforms==None: self.transforms=Compose([Resize((224,224)),ToTensor()]) else: self.transforms=transforms def__len__(self): returnlen(self.x) def__getitem__(self,idx): img=Image.open(self.x[idx]) img=self.transforms(img) returnimg,torch.tensor([[self.y[idx]]])
结果运行时报错:RuntimeError:invalidargument0:Sizesoftensorsmustmatchexceptindimension0.Got3and1indimension1at/opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897
Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)
。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。
Dataloader在制作batchdata时,tensor的shape必须一样,就报了这个错误。解决的方法是:img=img.convert(“RGB”)。完
整代码如下:
classpsDataset(Dataset): def__init__(self,x,y,transforms=None): super(Dataset,self).__init__() self.x=x self.y=y iftransforms==None: self.transforms=Compose([Resize((224,224)),ToTensor()]) else: self.transforms=transforms def__len__(self): returnlen(self.x) def__getitem__(self,idx): img=Image.open(self.x[idx]) img=img.convert("RGB") img=self.transforms(img) returnimg,torch.tensor([[self.y[idx]]])
以上这篇PyTorch解决Dataset和Dataloader遇到的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。