PyTorch加载自己的数据集实例详解
数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具,以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。
数据集存放大致有以下两种方式:
(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下:root/cat_dog/cat.01.jpg
root/cat_dog/cat.02.jpg
........................
root/cat_dog/dog.01.jpg
root/cat_dog/dog.02.jpg
......................
(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
................
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
..................
1.1对第1种数据集的处理步骤
(1)生成包含各文件名的列表(List)
(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码
(3)重写父类Dataset中的两个魔法方法:一个是:__lent__(self),其功能是len(Dataset),返回Dataset的样本数。另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。
(4)使用torch.utils.data.DataLoader加载数据集Dataset.
1.2实例详解
以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。
1.2.1数据集结构
所有数据集在cat-dog目录下:
.\cat_dog\cat.01.jpg
.\cat_dog\cat.02.jpg
.\cat_dog\cat.03.jpg
....................
.\cat_dog\dog.01.jpg
.\cat_dog\dog.02.jpg
....................
1.2.2导入需要用到的模块
fromtorch.utils.dataimportDataLoader,Dataset fromskimageimportio,transform importmatplotlib.pyplotasplt importoimporttorch fromtorchvisionimporttransforms,utils fromPILimportImage importpandasaspd importnumpyasnp #过滤警告信息 importwarnings warnings.filterwarnings("ignore")
1.2.3定义加载自定义数据的类
classMyDataset(Dataset):#继承Dataset def__init__(self,path_dir,transform=None):#初始化一些属性 self.path_dir=path_dir#文件路径,如'.\data\cat-dog' self.transform=transform#对图形进行处理,如标准化、截取、转换等 self.images=os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中 def__len__(self):#返回整个数据集的大小 returnlen(self.images) def__getitem__(self,index):#根据索引index返回图像及标签 image_index=self.images[index]#根据索引获取图像文件名称 img_path=os.path.join(self.path_dir,image_index)#获取图像的路径或目录 img=Image.open(img_path).convert('RGB')#读取图像 #根据目录名称获取图像标签(cat或dog) label=img_path.split('\\')[-1].split('.')[0] #把字符转换为数字cat-0,dog-1 label=1if'dog'inlabelelse0 ifself.transformisnotNone: img=self.transform(img) returnimg,label
1.2.4实例化类
dataset=MyDataset('.\data\cat-dog',transform=None) img,label=dataset[0]#将启动魔法方法__getitem__(0) print(type(img))
1.2.5查看图像形状
i=1
forimg,labelindataset:
ifi
img的形状(500,374),label的值0
img的形状(300,280),label的值0
img的形状(489,499),label的值0
img的形状(431,410),label的值0
img的形状(300,224),label的值0
从上面返回样本的形状来看:
(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。
(2)返回样本的数值较大,未归一化至[-1,1]
为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可
1.2.6对图像数据进行处理
这里使用torchvision中的transforms模块
fromtorchvisionimporttransformsasT transform=T.Compose([ T.Resize(224),#缩放图片(Image),保持长宽比不变,最短边为224像素 T.CenterCrop(224),#从图片中间切出224*224的图片 T.ToTensor(),#将图片(Image)转成Tensor,归一化至[0,1] T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])#标准化至[-1,1],规定均值和标准差 ])
1.2.7查看处理后的数据
dataset=MyDataset('.\data\cat-dog',transform=transform) forimg,labelindataset: print("图像img的形状{},标签label的值{}".format(img.shape,label)) print("图像数据预处理后:\n",img) break
图像img的形状torch.Size([3,224,224]),标签label的值0
图像数据预处理后:
tensor([[[0.9059,0.9137,0.9137,...,0.9451,0.9451,0.9451],
[0.9059,0.9137,0.9137,...,0.9451,0.9451,0.9451],
[0.9059,0.9137,0.9137,...,0.9529,0.9529,0.9529],
...,
[-0.4824,-0.5294,-0.5373,...,-0.9216,-0.9294,-0.9451],
[-0.4980,-0.5529,-0.5608,...,-0.9294,-0.9373,-0.9529],
[-0.4980,-0.5529,-0.5686,...,-0.9529,-0.9608,-0.9608]],
[[0.5686,0.5765,0.5765,...,0.7961,0.7882,0.7882],
[0.5686,0.5765,0.5765,...,0.7961,0.7882,0.7882],
[0.5686,0.5765,0.5765,...,0.8039,0.7961,0.7961],
...,
[-0.6078,-0.6471,-0.6549,...,-0.9137,-0.9216,-0.9373],
[-0.6157,-0.6706,-0.6784,...,-0.9216,-0.9294,-0.9451],
[-0.6157,-0.6706,-0.6863,...,-0.9451,-0.9529,-0.9529]],
[[-0.0510,-0.0431,-0.0431,...,0.2078,0.2157,0.2157],
[-0.0510,-0.0431,-0.0431,...,0.2078,0.2157,0.2157],
[-0.0510,-0.0431,-0.0431,...,0.2157,0.2235,0.2235],
...,
[-0.9529,-0.9843,-0.9922,...,-0.9529,-0.9608,-0.9765],
[-0.9686,-0.9922,-1.0000,...,-0.9608,-0.9686,-0.9843],
[-0.9686,-0.9922,-1.0000,...,-0.9843,-0.9922,-0.9922]]])
由此可知,数据已标准化、规范化。
1.2.8对数据集进行批量加载
使用DataLoader模块,对数据集dataset进行批量加载
#使用DataLoader加载数据 dataloader=DataLoader(dataset,batch_size=4,shuffle=True) forbatch_datas,batch_labelsindataloader: print(batch_datas.size(),batch_labels.size()) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([4,3,224,224])torch.Size([4]) torch.Size([2,3,224,224])torch.Size([2])
1.2.9随机查看一个批次的图像
importtorchvision importmatplotlib.pyplotasplt importnumpyasnp %matplotlibinline #显示图像 defimshow(img): img=img/2+0.5#unnormalize npimg=img.numpy() plt.imshow(np.transpose(npimg,(1,2,0))) plt.show() #随机获取部分训练数据 dataiter=iter(dataloader) images,labels=dataiter.next() #显示图像 imshow(torchvision.utils.make_grid(images)) #打印标签 print(''.join('%s'%["小狗"iflabels[j].item()==1else"小猫"forjinrange(4)]))
2对第2种数据集的处理
处理这种情况比较简单,可分为2步:
(1)使用datasets.ImageFolder读取、处理图像。
(2)使用.data.DataLoader批量加载数据集,示例如下:
importtorch fromtorchvisionimporttransforms,datasets data_transform=transforms.Compose([ transforms.RandomSizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) ]) hymenoptera_dataset=datasets.ImageFolder(root='.\catdog\train', transform=data_transform) dataset_loader=torch.utils.data.DataLoader(hymenoptera_dataset,
总结
到此这篇关于PyTorch加载自己的数据集实例详解的文章就介绍到这了,更多相关PyTorch加载数据集内容请搜索毛票票以前的文章或继续浏览下面的相关文章希望大家以后多多支持毛票票!