pytorch加载语音类自定义数据集的方法教程
前言
pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合
- torch.utils.data.Dataset:所有继承他的子类都应该重写 __len()__ ,__getitem()__这两个方法
- __len()__:返回数据集中数据的数量
- __getitem()__:返回支持下标索引方式获取的一个数据
- torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....
第一步
自定义的Dataset都需要继承torch.utils.data.Dataset类,并且重写它的两个成员方法:
- __len()__:读取数据,返回数据和标签
- __getitem()__:返回数据集的长度
fromtorch.utils.dataimportDataset classAudioDataset(Dataset): def__init__(self,...): """类的初始化""" pass def__getitem__(self,item): """每次怎么读数据,返回数据和标签""" returndata,label def__len__(self): """返回整个数据集的长度""" returntotal
注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本
案例:
文件目录结构
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:读取p225文件夹中的音频数据
classAudioDataset(Dataset): def__init__(self,data_folder,sr=16000,dimension=8192): self.data_folder=data_folder self.sr=sr self.dim=dimension #获取音频名列表 self.wav_list=[] forroot,dirnames,filenamesinos.walk(data_folder): forfilenameinfnmatch.filter(filenames,"*.wav"):#实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表 self.wav_list.append(os.path.join(root,filename)) def__getitem__(self,item): #读取一个音频文件,返回每个音频数据 filename=self.wav_list[item] wb_wav,_=librosa.load(filename,sr=self.sr) #取帧 iflen(wb_wav)>=self.dim: max_audio_start=len(wb_wav)-self.dim audio_start=np.random.randint(0,max_audio_start) wb_wav=wb_wav[audio_start:audio_start+self.dim] else: wb_wav=np.pad(wb_wav,(0,self.dim-len(wb_wav)),"constant") returnwb_wav,filename def__len__(self): #音频文件的总数 returnlen(self.wav_list)
注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
第二步
实例化Dataset对象
Dataset=AudioDataset("./p225",sr=16000)
如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作
#实例化AudioDataset对象 train_set=AudioDataset("./p225",sr=16000) fori,datainenumerate(train_set): wb_wav,filname=data print(i,wb_wav.shape,filname) ifi==3: break #0(8192,)./p225\p225_001.wav #1(8192,)./p225\p225_002.wav #2(8192,)./p225\p225_003.wav #3(8192,)./p225\p225_004.wav
第三步
如果想要通过batch读取数据,需要使用DataLoader进行包装
为何要使用DataLoader?
- 深度学习的输入是mini_batch形式
- 样本加载时候可能需要随机打乱顺序,shuffle操作
- 样本加载需要采用多线程
pytorch提供的DataLoader封装了上述的功能,这样使用起来更方便。
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False)
参数:
- dataset:加载的数据集(Dataset对象)
- batch_size:每个批次要加载多少个样本(默认值:1)
- shuffle:每个epoch是否将数据打乱
- sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
- batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
- num_workers:使用多进程加载的进程数,0代表不使用多线程
- collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
- pin_memory:是否将数据保存在pinmemory区,pinmemory中的数据转到GPU会快一些
- drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
返回:数据加载器
案例:
#实例化AudioDataset对象 train_set=AudioDataset("./p225",sr=16000) train_loader=DataLoader(train_set,batch_size=8,shuffle=True) for(i,data)inenumerate(train_loader): wav_data,wav_name=data print(wav_data.shape)#torch.Size([8,8192]) print(i,wav_name) #('./p225\\p225_293.wav','./p225\\p225_156.wav','./p225\\p225_277.wav','./p225\\p225_210.wav', #'./p225\\p225_126.wav','./p225\\p225_021.wav','./p225\\p225_257.wav','./p225\\p225_192.wav')
我们来吃几个栗子消化一下:
栗子1
这个例子就是本文一直举例的,栗子1只是合并了一下而已
文件目录结构
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:读取p225文件夹中的音频数据
importfnmatch importos importlibrosa importnumpyasnp fromtorch.utils.dataimportDataset fromtorch.utils.dataimportDataLoader classAduio_DataLoader(Dataset): def__init__(self,data_folder,sr=16000,dimension=8192): self.data_folder=data_folder self.sr=sr self.dim=dimension #获取音频名列表 self.wav_list=[] forroot,dirnames,filenamesinos.walk(data_folder): forfilenameinfnmatch.filter(filenames,"*.wav"):#实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表 self.wav_list.append(os.path.join(root,filename)) def__getitem__(self,item): #读取一个音频文件,返回每个音频数据 filename=self.wav_list[item] print(filename) wb_wav,_=librosa.load(filename,sr=self.sr) #取帧 iflen(wb_wav)>=self.dim: max_audio_start=len(wb_wav)-self.dim audio_start=np.random.randint(0,max_audio_start) wb_wav=wb_wav[audio_start:audio_start+self.dim] else: wb_wav=np.pad(wb_wav,(0,self.dim-len(wb_wav)),"constant") returnwb_wav,filename def__len__(self): #音频文件的总数 returnlen(self.wav_list) train_set=Aduio_DataLoader("./p225",sr=16000) train_loader=DataLoader(train_set,batch_size=8,shuffle=True) for(i,data)inenumerate(train_loader): wav_data,wav_name=data print(wav_data.shape)#torch.Size([8,8192]) print(i,wav_name) #('./p225\\p225_293.wav','./p225\\p225_156.wav','./p225\\p225_277.wav','./p225\\p225_210.wav', #'./p225\\p225_126.wav','./p225\\p225_021.wav','./p225\\p225_257.wav','./p225\\p225_192.wav')
注意事项:
- 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
- 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意
栗子2
相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。
我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num,frame_len,1)的数组保存到h5中。然后用上面讲到的torch.utils.data.Dataset和torch.utils.data.DataLoader读取数据。
具体实现代码:
第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:
第二步:通过Dataset从h5格式文件中读取数据
importnumpyasnp fromtorch.utils.dataimportDataset fromtorch.utils.dataimportDataLoader importh5py defload_h5(h5_path): #loadtrainingdata withh5py.File(h5_path,'r')ashf: print('Listofarraysininputfile:',hf.keys()) X=np.array(hf.get('data'),dtype=np.float32) Y=np.array(hf.get('label'),dtype=np.float32) returnX,Y classAudioDataset(Dataset): """数据加载器""" def__init__(self,data_folder): self.data_folder=data_folder self.X,self.Y=load_h5(data_folder)#(3392,8192,1) def__getitem__(self,item): #返回一个音频数据 X=self.X[item] Y=self.Y[item] returnX,Y def__len__(self): returnlen(self.X) train_set=AudioDataset("./speaker225_resample_train.h5") train_loader=DataLoader(train_set,batch_size=64,shuffle=True,drop_last=True) for(i,wav_data)inenumerate(train_loader): X,Y=wav_data print(i,X.shape) #0torch.Size([64,8192,1]) #1torch.Size([64,8192,1]) #...
我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,
参考
pytorch学习(四)—自定义数据集(讲的比较详细)
总结
到此这篇关于pytorch加载语音类自定义数据集的文章就介绍到这了,更多相关pytorch加载语音类自定义数据集内容请搜索毛票票以前的文章或继续浏览下面的相关文章希望大家以后多多支持毛票票!