使用PyTorch将文件夹下的图片分为训练集和验证集实例
PyTorch(torchvision.datasets)提供了 ImageFolder 类,用于加载文件结构如下的图片数据集:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。为此,我写了两个类来完成这个工作。
import os

import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose


class ImageFolderSplitter:
    """Scan a class-per-folder image tree and split it into train/validation lists.

    Images should be placed in folders like::

        root
        root/dogs
        root/dogs/image1.png
        root/dogs/image2.png
        root/cats
        root/cats/image1.png
        root/cats/image2.png

    The root must contain only class folders (at least two), and every class
    folder must contain only files (at least one). Every file found in a class
    folder is treated as a sample; no extension filtering is performed.

    Attributes:
        class2num: maps class folder name -> integer label.
        num2class: maps integer label -> class folder name.
        class_nums: maps integer label -> number of files found for that class.
        data_x_path: every sample path, before splitting.
        data_y_label: the integer label of each entry in ``data_x_path``.
        x_train, y_train: paths and labels of the training split.
        x_valid, y_valid: paths and labels of the validation split.
    """

    def __init__(self, path, train_size=0.8, random_state=None):
        """Walk ``path``, build the label maps and perform the split.

        Args:
            path: the root of the image folder.
            train_size: forwarded to ``train_test_split`` as ``train_size``.
            random_state: forwarded to ``train_test_split``; pass an int to get
                the same split on every run. The default ``None`` keeps the
                split random.

        Raises:
            RuntimeError: if the folder layout does not match the structure
                described in the class docstring.
        """
        self.path = path
        self.train_size = train_size
        self.class2num = {}
        self.num2class = {}
        self.class_nums = {}
        self.data_x_path = []
        self.data_y_label = []
        self.x_train = []
        self.x_valid = []
        self.y_train = []
        self.y_valid = []

        at_top = True
        for root, dirs, files in os.walk(path):
            # Sort in place: os.walk descends into `dirs` in the order it is
            # left in, so label numbering and sample order become reproducible
            # instead of depending on the file system's listing order.
            dirs.sort()
            files.sort()

            if at_top:
                # The first tuple yielded by os.walk describes `path` itself:
                # it must hold class folders only.
                at_top = False
                if files or len(dirs) < 2:
                    raise RuntimeError("please check the folder structure!")
                for label, class_name in enumerate(dirs):
                    self.num2class[label] = class_name
                    self.class2num[class_name] = label
                continue

            # Anything below the top level must be a flat, non-empty class
            # folder. A folder with a single image is valid.
            if dirs or not files:
                raise RuntimeError("please check the folder structure!")

            # Use the folder's own name rather than a substring search over
            # the whole path, so a class called "cat" is never confused with
            # "cats_wild" or with a root directory whose name contains "cat".
            label = self.class2num[os.path.basename(root)]
            self.class_nums[label] = len(files)
            self.data_x_path.extend(os.path.join(root, name) for name in files)
            self.data_y_label.extend([label] * len(files))

        (
            self.x_train,
            self.x_valid,
            self.y_train,
            self.y_valid,
        ) = train_test_split(
            self.data_x_path,
            self.data_y_label,
            shuffle=True,
            train_size=self.train_size,
            random_state=random_state,
        )

    def getTrainingDataset(self):
        """Return ``(paths, labels)`` of the training split."""
        return self.x_train, self.y_train

    def getValidationDataset(self):
        """Return ``(paths, labels)`` of the validation split."""
        return self.x_valid, self.y_valid


class DatasetFromFilename(Dataset):
    """Dataset that loads images lazily from a list of file paths.

    Args:
        x: a list of image file full paths.
        y: a list of image categories, aligned with ``x``.
        transforms: callable applied to each RGB PIL image; defaults to
            ``ToTensor()`` when omitted.
    """

    def __init__(self, x, y, transforms=None):
        super().__init__()
        self.x = x
        self.y = y
        self.transforms = ToTensor() if transforms is None else transforms

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_tensor)`` for sample ``idx``.

        The label is wrapped as ``torch.tensor([[label]])`` (shape ``(1, 1)``),
        which is kept as-is for compatibility with existing callers.
        """
        # Image.open keeps the file handle open until the data is loaded;
        # convert() returns an independent, fully loaded copy, so the handle
        # can be released as soon as the conversion is done.
        with Image.open(self.x[idx]) as raw:
            img = raw.convert("RGB")
        return self.transforms(img), torch.tensor([[self.y[idx]]])


# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#     print(x.shape, y.shape)
更多的代码可以在我的 GitHub repo 下找到。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。