pytorch::Dataloader中的迭代器和生成器应用详解
在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader。
为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。
这一应用场景正是python中迭代器模式的意义所在,因此本文对Dataloader中代码进行解读,可以更好的理解python中迭代器和生成器的概念。
本文的内容主要有:
- 解释python中的迭代器和生成器概念
- 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数据加载
python迭代基础
python中围绕着迭代有以下概念:
- 可迭代对象iterables
- 迭代器iterator
- 生成器generator
这三个概念互相关联,并不是孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。
学习这些概念的名词解释没有多大意义。编程中很多的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。
因此,要理解它们,需要探究概念背后的逻辑,为什么这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?
迭代模式首先要解决的基础问题是,需要按一定顺序获取集合内部数据,比如循环某个list。
当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,因此想出以下方法:
- 把大的数据分成几个小块,分批处理
- 惰性的取值方式,按需取值
循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:
- forxincontainer:为了遍历python内部序列容器(如list),这些类型内部实现了__getitem__()方法,可以从0开始按顺序遍历序列容器中的元素。
- forxiniterator:为了循环用户自定义的迭代器,需要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在__next__或next方法里
- forxingenerator:为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是range(5)
代码示例:
#普通循环forxinlist numbers=[1,2,3,] forninnumbers: print(n)#1,2,3 #for循环实际干的事情 #iter输入一个可迭代对象list,返回迭代器 #next方法取数据 my_iterator=iter(numbers) next(my_iterator)#1 next(my_iterator)#2 next(my_iterator)#3 next(my_iterator)#StopIterationexception #迭代器循环forxiniterator fori,ninenumerate(numbers): print(i,n)#0,1/1,3/2,3 #生成器循环forxingenerator foriinrange(3): print(i)#0,1,2
上面示例代码中python内置函数iter和next的用法:
- iter函数,调用__iter__,返回一个迭代器
- next函数,输入迭代器,调用__next__,取出数据
比较容易混淆的是__iter__和__next__两个方法。它们的区别是:
- __iter__是为了可以迭代,真正执行取数据的逻辑是__next__方法实现的,实际调用是通过next(iterator)完成
- __iter__可以返回自身(returnself),实际读取数据的实现放在__next__方法
- __iter__可以和yield搭配,返回生成器对象
__iter__返回自身的做法有点类似python中的类型系统。为了保持一致性,python中一切皆对象。
每个对象创建后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。
生成器,是在__iter__方法中加入yield语句,好处有:
- 减少循环判断逻辑的复杂度
- 惰性取值,节省内存和时间
yield作用:
- 代替函数中的return语句
- 记住上一次循环迭代器内部元素的位置
三种循环模式常用函数
forxincontainer方法:
- list,deque,…
- set,frozensets,…
- dict,defaultdict,OrderedDict,Counter,…
- tuple,namedtuple,…
- str
forxiniterator方法:
- enumerate()#加上list的index
- sorted()#排序list
- reversed()#倒序list
- zip()#合并list
forxingenerator方法:
- range()
- map()
- filter()
- reduce()
- [xforxinlist(...)]
Dataloder源码分析
pytorch采用 forxiniterator模式,从Dataloader类中读取数据。
- 为了实现该迭代模式,在Dataloader内部实现__iter__方法,实际返回的是_DataLoaderIter类。
- _DataLoaderIter类里面,实现了__iter__方法,返回自身,具体执行读数据的逻辑,在__next__方法中。
以下代码只截取了单线程下的数据读取。
classDataLoader(object): r""" Dataloader.Combinesadatasetandasampler,andprovides single-ormulti-processiteratorsoverthedataset. """ def__init__(self,dataset,batch_size=1,shuffle=False,...): self.dataset=dataset self.batch_sampler=batch_sampler ... def__iter__(self): return_DataLoaderIter(self) def__len__(self): returnlen(self.batch_sampler) class_DataLoaderIter(object): r"""IteratesonceovertheDataLoader'sdataset,asspecifiedbythesampler""" def__init__(self,loader): self.sample_iter=iter(self.batch_sampler) ... def__next__(self): ifself.num_workers==0:#same-processloading indices=next(self.sample_iter)#mayraiseStopIteration batch=self.collate_fn([self.dataset[i]foriinindices]) ifself.pin_memory: batch=pin_memory_batch(batch) returnbatch ... def__iter__(self): returnself
Dataloader类中读取数据Index的方法,采用了 forxingenerator方式,但是调用采用iter和next函数
- 构建随机采样类RandomSampler,内部实现了__iter__方法
- __iter__方法内部使用了yield,循环遍历数据集,当数量达到batch_size大小时,就返回
- 实例化随机采样类,传入iter函数,返回一个迭代器
- next会调用随机采样类中生成器,返回相应的index数据
classRandomSampler(object): """randomsamplertoyieldamini-batchofindices.""" def__init__(self,batch_size,dataset,drop_last=False): self.dataset=dataset self.batch_size=batch_size self.num_imgs=len(dataset) self.drop_last=drop_last def__iter__(self): indices=np.random.permutation(self.num_imgs) batch=[] foriinindices: batch.append(i) iflen(batch)==self.batch_size: yieldbatch batch=[] ##ifimagesnottoyieldabatch iflen(batch)>0andnotself.drop_last: yieldbatch def__len__(self): ifself.drop_last: returnself.num_imgs//self.batch_size else: return(self.num_imgs+self.batch_size-1)//self.batch_size batch_sampler=RandomSampler(batch_size.dataset) sample_iter=iter(batch_sampler) indices=next(sample_iter)
总结
本文总结了python中循环的三种模式:
- forxincontainer可迭代对象
- forxiniterator迭代器
- forxingenerator生成器
pytorch中的数据加载模块Dataloader,使用生成器来返回数据的索引,使用迭代器来返回需要的张量数据,可以在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足问题。
参考文章
迭代器和生成器
流畅的Python-第14章:可迭代的对象、迭代器和生成器
pytorch-dataloader源码
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。