Concrete usage of next_batch in TensorFlow
This article introduces how next_batch is used in TensorFlow. Several different next_batch implementations are given below; the article only explains these code snippets, as a reference for later:
The first approach:

def next_batch(self, batch_size, fake_data=False):
  """Return the next `batch_size` examples from this data set."""
  if fake_data:
    fake_image = [1] * 784
    if self.one_hot:
      fake_label = [1] + [0] * 9
    else:
      fake_label = 0
    return [fake_image for _ in xrange(batch_size)], [
        fake_label for _ in xrange(batch_size)
    ]
  start = self._index_in_epoch
  self._index_in_epoch += batch_size
  if self._index_in_epoch > self._num_examples:
    # The index has run past the number of examples, so one full pass
    # (epoch) over the corpus is finished; start a new one.
    # Finished epoch
    self._epochs_completed += 1
    # Shuffle the data
    perm = numpy.arange(self._num_examples)  # arange builds the index array 0..n-1
    numpy.random.shuffle(perm)               # shuffle the indices in place
    self._images = self._images[perm]
    self._labels = self._labels[perm]
    # Start next epoch
    start = 0
    self._index_in_epoch = batch_size
    assert batch_size <= self._num_examples
  end = self._index_in_epoch
  return self._images[start:end], self._labels[start:end]
This code is taken from mnist.py. Starting from the line start = self._index_in_epoch: _index_in_epoch - 1 is the index of the last image of the previous batch, so the first image of the current batch has index _index_in_epoch, and the last one has index _index_in_epoch + batch_size - 1. If _index_in_epoch runs past the number of images in the corpus, the current position cannot yield a full batch; one complete pass over the corpus is finished, so the images are shuffled and batches for a new pass start from the beginning.
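To see the wrap-around behaviour in isolation, here is a minimal, self-contained sketch of the same logic (the DataSet class and the toy 10-example dataset are made up for illustration; the fake_data branch is omitted):

```python
import numpy as np

class DataSet(object):
    """Toy dataset reproducing the epoch/shuffle logic described above."""
    def __init__(self, images, labels):
        self._images = images
        self._labels = labels
        self._num_examples = len(images)
        self._index_in_epoch = 0
        self._epochs_completed = 0

    def next_batch(self, batch_size):
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # One full pass is finished: shuffle and start a new epoch.
            self._epochs_completed += 1
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]

ds = DataSet(np.arange(10), np.arange(10))
xs1, _ = ds.next_batch(4)  # examples 0..3
xs2, _ = ds.next_batch(4)  # examples 4..7
xs3, _ = ds.next_batch(4)  # overflow: reshuffle, return first 4 of the new order
```

Note that the third call silently discards the two unused tail examples before reshuffling; that is a property of the original logic, not of this sketch.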
The second approach:

def ptb_iterator(raw_data, batch_size, num_steps):
  """Iterate on the raw PTB data.

  This generates batch_size pointers into the raw PTB data, and allows
  minibatch iteration along these pointers.

  Args:
    raw_data: one of the raw data outputs from ptb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.

  Yields:
    Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
    The second element of the tuple is the same data time-shifted to the
    right by one.

  Raises:
    ValueError: if batch_size or num_steps are too high.
  """
  raw_data = np.array(raw_data, dtype=np.int32)

  data_len = len(raw_data)
  batch_len = data_len // batch_size  # number of words in each of the batch_size rows
  data = np.zeros([batch_size, batch_len], dtype=np.int32)
  for i in range(batch_size):  # cut the corpus into batch_size contiguous rows
    data[i] = raw_data[batch_len * i:batch_len * (i + 1)]

  epoch_size = (batch_len - 1) // num_steps  # number of minibatches per pass
  # equivalently: epoch_size = ((len(data) // model.batch_size) - 1) // model.num_steps
  # (// is integer division)

  if epoch_size == 0:
    raise ValueError("epoch_size == 0, decrease batch_size or num_steps")

  for i in range(epoch_size):
    x = data[:, i * num_steps:(i + 1) * num_steps]
    y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
    yield (x, y)
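As a quick check of the shapes and of the one-step shift between x and y, the iterator can be exercised on a toy corpus of 20 token ids (the function is repeated here, without the docstring, so the snippet runs on its own):

```python
import numpy as np

def ptb_iterator(raw_data, batch_size, num_steps):
    raw_data = np.array(raw_data, dtype=np.int32)
    data_len = len(raw_data)
    batch_len = data_len // batch_size      # words per row
    data = np.zeros([batch_size, batch_len], dtype=np.int32)
    for i in range(batch_size):
        data[i] = raw_data[batch_len * i:batch_len * (i + 1)]
    epoch_size = (batch_len - 1) // num_steps  # minibatches per pass
    if epoch_size == 0:
        raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
        yield (x, y)

# 20 tokens, 2 rows of 10 words each; (10 - 1) // 3 = 3 minibatches
batches = list(ptb_iterator(range(20), batch_size=2, num_steps=3))
x0, y0 = batches[0]
# x0 is [[0, 1, 2], [10, 11, 12]]; y0 is the same data shifted right by one
```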
The third approach:
def next(self, batch_size):
  """Return a batch of data. When dataset end is reached, start over."""
  if self.batch_id == len(self.data):
    self.batch_id = 0
  batch_data = (self.data[self.batch_id:min(self.batch_id +
                batch_size, len(self.data))])
  batch_labels = (self.labels[self.batch_id:min(self.batch_id +
                  batch_size, len(self.data))])
  batch_seqlen = (self.seqlen[self.batch_id:min(self.batch_id +
                  batch_size, len(self.data))])
  self.batch_id = min(self.batch_id + batch_size, len(self.data))
  return batch_data, batch_labels, batch_seqlen
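A minimal harness for this method might look as follows (the SequenceData container and the toy lists are invented for illustration; note the short final batch and the wrap-around on the third call):

```python
class SequenceData(object):
    """Toy container exposing the next() method shown above."""
    def __init__(self, data, labels, seqlen):
        self.data, self.labels, self.seqlen = data, labels, seqlen
        self.batch_id = 0

    def next(self, batch_size):
        """Return a batch of data. When dataset end is reached, start over."""
        if self.batch_id == len(self.data):
            self.batch_id = 0
        end = min(self.batch_id + batch_size, len(self.data))
        batch_data = self.data[self.batch_id:end]
        batch_labels = self.labels[self.batch_id:end]
        batch_seqlen = self.seqlen[self.batch_id:end]
        self.batch_id = end
        return batch_data, batch_labels, batch_seqlen

ds = SequenceData(list(range(5)), list(range(5)), [1] * 5)
d1, _, _ = ds.next(3)  # [0, 1, 2]
d2, _, _ = ds.next(3)  # [3, 4] -- short final batch, not padded
d3, _, _ = ds.next(3)  # end reached, so it wraps: [0, 1, 2] again
```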
The fourth approach:
def batch_iter(sourceData, batch_size, num_epochs, shuffle=True):
  data = np.array(sourceData)  # store sourceData as a NumPy array
  data_size = len(data)
  # ceiling division, so an exactly divisible data_size does not yield an empty final batch
  num_batches_per_epoch = (data_size + batch_size - 1) // batch_size
  for epoch in range(num_epochs):
    # Shuffle the data at each epoch
    if shuffle:
      shuffle_indices = np.random.permutation(np.arange(data_size))
      shuffled_data = data[shuffle_indices]  # index the array, not the raw input
    else:
      shuffled_data = data
    for batch_num in range(num_batches_per_epoch):
      start_index = batch_num * batch_size
      end_index = min((batch_num + 1) * batch_size, data_size)
      yield shuffled_data[start_index:end_index]
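Driving the generator on a toy range shows the batch layout per epoch. The function is repeated here, with two small fixes so the snippet runs on its own: the batch count uses ceiling division (the original int(len/batch_size) + 1 emits an empty final batch whenever data_size divides evenly), and the shuffle indexes the NumPy array rather than the raw input, which may be a plain list:

```python
import numpy as np

def batch_iter(sourceData, batch_size, num_epochs, shuffle=True):
    data = np.array(sourceData)
    data_size = len(data)
    # ceiling division: number of batches per epoch
    num_batches_per_epoch = (data_size + batch_size - 1) // batch_size
    for epoch in range(num_epochs):
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]

# 10 examples, batch_size 4 -> batches of 4, 4, 2 per epoch; 6 batches over 2 epochs
batches = list(batch_iter(range(10), batch_size=4, num_epochs=2, shuffle=False))
```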
This last function is a generator; for details on how it is consumed, study how Python iterators and generators work.
Also note that the first three approaches only traverse the corpus one pass at a time, whereas the last one iterates over the entire corpus num_epochs times.