pytorch sampler对数据进行采样的实现
PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法:WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。
构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。
下面举例说明。
fromdataSetimport* dataset=DogCat('data/dogcat/',transform=transform) fromtorch.utils.dataimportDataLoader #狗的图片被取出的概率是猫的概率的两倍 #两类图片被取出的概率与weights的绝对大小无关,只和比值有关 weights=[2iflabel==1else1fordata,labelindataset] print(weights) fromtorch.utils.data.samplerimportWeightedRandomSampler sampler=WeightedRandomSampler(weights,\ num_samples=9,\ replacement=True) dataloader=DataLoader(dataset, batch_size=3, sampler=sampler) fordatas,labelsindataloader: print(labels.tolist())
输出:
[2,2,1,1,2,1,1,2] [1,1,0] [1,0,0] [0,0,1]
github地址:
https://github.com/WebLearning17/CommonTool
以上这篇pytorchsampler对数据进行采样的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。