pytorch ImageFolder的覆写实例
在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:
CLASStorchvision.datasets.ImageFolder(root,transform=None,target_transform=None,loader=
,is_valid_file=None)
使用可见pytorchtorchvision.ImageFolder的用法介绍
这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能
首先先分析下其源代码:
IMG_EXTENSIONS=['.jpg','.jpeg','.png','.ppm','.bmp','.pgm','.tif','.tiff','webp'] classImageFolder(DatasetFolder): """Agenericdataloaderwheretheimagesarearrangedinthisway::: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png Args: root(string):Rootdirectorypath. transform(callable,optional):Afunction/transformthattakesinanPILimage andreturnsatransformedversion.E.g,``transforms.RandomCrop`` target_transform(callable,optional):Afunction/transformthattakesinthe targetandtransformsit. loader(callable,optional):Afunctiontoloadanimagegivenitspath. Attributes: classes(list):Listoftheclassnames. class_to_idx(dict):Dictwithitems(class_name,class_index). imgs(list):Listof(imagepath,class_index)tuples """ def__init__(self,root,transform=None,target_transform=None, loader=default_loader): super(ImageFolder,self).__init__(root,loader,IMG_EXTENSIONS, transform=transform, target_transform=target_transform) self.imgs=self.samples
ImageFolder的代码很简单,主要是继承了DatasetFolder:
defhas_file_allowed_extension(filename,extensions): """查看文件是否是支持的可扩展类型 Args: filename(string):文件路径 extensions(iterableofstrings):可扩展类型列表,即能接受的图像文件类型 Returns: bool:Trueifthefilenameendswithoneofgivenextensions """ filename_lower=filename.lower() returnany(filename_lower.endswith(ext)forextinextensions)#返回True或False列表 defmake_dataset(dir,class_to_idx,extensions): """ 返回形如[(图像路径,该图像对应的类别索引值),(),...] """ images=[] dir=os.path.expanduser(dir) fortargetinsorted(class_to_idx.keys()): d=os.path.join(dir,target) ifnotos.path.isdir(d): continue forroot,_,fnamesinsorted(os.walk(d)):#层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名 forfnameinsorted(fnames): ifhas_file_allowed_extension(fname,extensions):查看文件是否是支持的可扩展类型,是则继续 path=os.path.join(root,fname) item=(path,class_to_idx[target]) images.append(item) returnimages classDatasetFolder(data.Dataset): """Agenericdataloaderwherethesamplesarearrangedinthisway::: root/class_x/xxx.ext root/class_x/xxy.ext root/class_x/xxz.ext root/class_y/123.ext root/class_y/nsdf3.ext root/class_y/asd932_.ext Args: root(string):根目录路径 loader(callable):根据给定的路径来加载样本的可调用函数 extensions(list[string]):可扩展类型列表,即能接受的图像文件类型. transform(callable,optional):用于样本的transform函数,然后返回样本transform后的版本 E.g,``transforms.RandomCrop``forimages. target_transform(callable,optional):用于样本标签的transform函数 Attributes: classes(list):类别名列表 class_to_idx(dict):项目(class_name,class_index)字典,如{'cat':0,'dog':1} samples(list):(samplepath,class_index)元组列表,即(样本路径,类别索引) targets(list):在数据集中每张图片的类索引值,为列表 """ def__init__(self,root,loader,extensions,transform=None,target_transform=None): classes,class_to_idx=self._find_classes(root)#得到类名和类索引,如['cat','dog']和{'cat':0,'dog':1} #返回形如[(图像路径,该图像对应的类别索引值),(),...],即对每个图像进行标记 samples=make_dataset(root,class_to_idx,extensions) iflen(samples)==0: raise(RuntimeError("Found0filesinsubfoldersof:"+root+"\n" "Supportedextensionsare:"+",".join(extensions))) self.root=root self.loader=loader self.extensions=extensions self.classes=classes self.class_to_idx=class_to_idx self.samples=samples self.targets=[s[1]forsinsamples]#所有图像的类索引值组成的列表 self.transform=transform self.target_transform=target_transform def_find_classes(self,dir): """ 在数据集中查找类文件夹。 Args: dir(string):根目录路径 Returns: 返回元组:(classes,class_to_idx)即(类名,类索引),其中classes即相应的目录名,如['cat','dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat':0,'dog':1}. Ensures: 保证没有类名是另一个类目录的子目录 """ ifsys.version_info>=(3,5): #FasterandavailableinPython3.5andabove classes=[d.namefordinos.scandir(dir)ifd.is_dir()]#获得根目录dir的所有第一层子目录名 else: classes=[dfordinos.listdir(dir)ifos.path.isdir(os.path.join(dir,d))]#效果和上面的一样,只是版本不同方法不同 classes.sort()#然后对类名进行排序 class_to_idx={classes[i]:iforiinrange(len(classes))}#然后将类名和索引值一一对应的到相应字典,如{'cat':0,'dog':1} returnclasses,class_to_idx#然后返回类名和类索引 def__getitem__(self,index): """ Args: index(int):Index Returns: tuple:(sample,target)wheretargetisclass_indexofthetargetclass. """ path,target=self.samples[index] sample=self.loader(path)#加载图片 ifself.transformisnotNone: sample=self.transform(sample) ifself.target_transformisnotNone: target=self.target_transform(target) returnsample,target def__len__(self): returnlen(self.samples) def__repr__(self): fmt_str='Dataset'+self.__class__.__name__+'\n' fmt_str+='Numberofdatapoints:{}\n'.format(self.__len__()) fmt_str+='RootLocation:{}\n'.format(self.root) tmp='Transforms(ifany):' fmt_str+='{0}{1}\n'.format(tmp,self.transform.__repr__().replace('\n','\n'+''*len(tmp))) tmp='TargetTransforms(ifany):' fmt_str+='{0}{1}'.format(tmp,self.target_transform.__repr__().replace('\n','\n'+''*len(tmp))) returnfmt_str
此时想要覆写ImageFolder,代码为:
classCustomImageFolder(ImageFolder): """ 为了得到两张图(其中一张是随机选取的)的图像和索引值信息 """ def__init__(self,root,transform=None): super(CustomImageFolder,self).__init__(root,transform) self.indices=range(len(self))#该文件夹中的长度 def__getitem__(self,index1): index2=random.choice(self.indices)#从[0,indices]中随机抽取一个数字,为了随机选取一张图 path1=self.imgs[index1][0]#此时的self.imgs等于self.samples,即内容为[(图像路径,该图像对应的类别索引值),(),...] label1=self.imgs[index1][1] path2=self.imgs[index2][0] label2=self.imgs[index2][1] img1=self.loader(path1) img2=self.loader(path2) ifself.transformisnotNone: img1=self.transform(img1) img2=self.transform(img2) returnimg1,img2,label1,label2
以上这篇pytorchImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。