Tensorflow之构建自己的图片数据集TFrecords的方法
学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。
tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图片数据集tfrecords。
流程是:制作数据集—读取数据集—-加入队列
先贴完整的代码:
#encoding=utf-8
importos
importtensorflowastf
fromPILimportImage
cwd=os.getcwd()
classes={'test','test1','test2'}
#制作二进制数据
defcreate_record():
writer=tf.python_io.TFRecordWriter("train.tfrecords")
forindex,nameinenumerate(classes):
class_path=cwd+"/"+name+"/"
forimg_nameinos.listdir(class_path):
img_path=class_path+img_name
img=Image.open(img_path)
img=img.resize((64,64))
img_raw=img.tobytes()#将图片转化为原生bytes
printindex,img_raw
example=tf.train.Example(
features=tf.train.Features(feature={
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
data=create_record()
#读取二进制数据
defread_and_decode(filename):
#创建文件队列,不限读取的数量
filename_queue=tf.train.string_input_producer([filename])
#createareaderfromfilequeue
reader=tf.TFRecordReader()
#reader从文件队列中读入一个序列化的样本
_,serialized_example=reader.read(filename_queue)
#getfeaturefromserializedexample
#解析符号化的样本
features=tf.parse_single_example(
serialized_example,
features={
'label':tf.FixedLenFeature([],tf.int64),
'img_raw':tf.FixedLenFeature([],tf.string)
}
)
label=features['label']
img=features['img_raw']
img=tf.decode_raw(img,tf.uint8)
img=tf.reshape(img,[64,64,3])
img=tf.cast(img,tf.float32)*(1./255)-0.5
label=tf.cast(label,tf.int32)
returnimg,label
if__name__=='__main__':
if0:
data=create_record("train.tfrecords")
else:
img,label=read_and_decode("train.tfrecords")
print"tengxing",img,label
#使用shuffle_batch可以随机打乱输入next_batch挨着往下取
#shuffle_batch才能实现[img,label]的同步,也即特征和label的同步,不然可能输入的特征和label不匹配
#比如只有这样使用,才能使img和label一一对应,每次提取一个image和对应的label
#shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果
#Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的[img,label],送入队列中
img_batch,label_batch=tf.train.shuffle_batch([img,label],
batch_size=4,capacity=2000,
min_after_dequeue=1000)
#初始化所有的op
init=tf.initialize_all_variables()
withtf.Session()assess:
sess.run(init)
#启动队列
threads=tf.train.start_queue_runners(sess=sess)
foriinrange(5):
printimg_batch.shape,label_batch
val,l=sess.run([img_batch,label_batch])
#l=to_categorical(l,12)
print(val.shape,l)
制作数据集
#制作二进制数据
defcreate_record():
cwd=os.getcwd()
classes={'1','2','3'}
writer=tf.python_io.TFRecordWriter("train.tfrecords")
forindex,nameinenumerate(classes):
class_path=cwd+"/"+name+"/"
forimg_nameinos.listdir(class_path):
img_path=class_path+img_name
img=Image.open(img_path)
img=img.resize((28,28))
img_raw=img.tobytes()#将图片转化为原生bytes
#printindex,img_raw
example=tf.train.Example(
features=tf.train.Features(
feature={
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}
)
)
writer.write(example.SerializeToString())
writer.close()
TFRecords文件包含了tf.train.Example协议内存块(protocolbuffer)(协议内存块包含了字段Features)。我们可以写一段代码获取你的数据,将数据填入到Example协议内存块(protocolbuffer),将协议内存块序列化为一个字符串,并且通过tf.python_io.TFRecordWriter写入到TFRecords文件。
读取数据集
#读取二进制数据
defread_and_decode(filename):
#创建文件队列,不限读取的数量
filename_queue=tf.train.string_input_producer([filename])
#createareaderfromfilequeue
reader=tf.TFRecordReader()
#reader从文件队列中读入一个序列化的样本
_,serialized_example=reader.read(filename_queue)
#getfeaturefromserializedexample
#解析符号化的样本
features=tf.parse_single_example(
serialized_example,
features={
'label':tf.FixedLenFeature([],tf.int64),
'img_raw':tf.FixedLenFeature([],tf.string)
}
)
label=features['label']
img=features['img_raw']
img=tf.decode_raw(img,tf.uint8)
img=tf.reshape(img,[64,64,3])
img=tf.cast(img,tf.float32)*(1./255)-0.5
label=tf.cast(label,tf.int32)
returnimg,label
一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个FloatList,或者ByteList,或者Int64List
加入队列
withtf.Session()assess: sess.run(init) #启动队列 threads=tf.train.start_queue_runners(sess=sess) foriinrange(5): printimg_batch.shape,label_batch val,l=sess.run([img_batch,label_batch]) #l=to_categorical(l,12) print(val.shape,l)
这样就可以的到和tensorflow官方的二进制数据集了,
注意:
- 启动队列那条code不要忘记,不然卡死
- 使用的时候记得使用val和l,不然会报类型错误:TypeError:Thevalueofafeedcannotbeatf.Tensorobject.AcceptablefeedvaluesincludePythonscalars,strings,lists,ornumpyndarrays.
- 算交叉熵时候:cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels)算交叉熵
- 最后评估的时候用tf.nn.in_top_k(logits,labels,1)选logits最大的数的索引和label比较
- cross_entropy=-tf.reduce_sum(y_*tf.log(y_conv))算交叉熵,所以label必须转成one-hot向量
实例2:将图片文件夹下的图片转存tfrecords的数据集。
############################################################################################
#!/usr/bin/python2.7
#-*-coding:utf-8-*-
#Author:zhaoqinghui
#Date:2016.5.10
#Function:imageconverttotfrecords
#############################################################################################
importtensorflowastf
importnumpyasnp
importcv2
importos
importos.path
fromPILimportImage
#参数设置
###############################################################################################
train_file='train.txt'#训练图片
name='train'#生成train.tfrecords
output_directory='./tfrecords'
resize_height=32#存储图片高度
resize_width=32#存储图片宽度
###############################################################################################
def_int64_feature(value):
returntf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def_bytes_feature(value):
returntf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
defload_file(examples_list_file):
lines=np.genfromtxt(examples_list_file,delimiter="",dtype=[('col1','S120'),('col2','i8')])
examples=[]
labels=[]
forexample,labelinlines:
examples.append(example)
labels.append(label)
returnnp.asarray(examples),np.asarray(labels),len(lines)
defextract_image(filename,resize_height,resize_width):
image=cv2.imread(filename)
image=cv2.resize(image,(resize_height,resize_width))
b,g,r=cv2.split(image)
rgb_image=cv2.merge([r,g,b])
returnrgb_image
deftransform2tfrecord(train_file,name,output_directory,resize_height,resize_width):
ifnotos.path.exists(output_directory)oros.path.isfile(output_directory):
os.makedirs(output_directory)
_examples,_labels,examples_num=load_file(train_file)
filename=output_directory+"/"+name+'.tfrecords'
writer=tf.python_io.TFRecordWriter(filename)
fori,[example,label]inenumerate(zip(_examples,_labels)):
print('No.%d'%(i))
image=extract_image(example,resize_height,resize_width)
print('shape:%d,%d,%d,label:%d'%(image.shape[0],image.shape[1],image.shape[2],label))
image_raw=image.tostring()
example=tf.train.Example(features=tf.train.Features(feature={
'image_raw':_bytes_feature(image_raw),
'height':_int64_feature(image.shape[0]),
'width':_int64_feature(image.shape[1]),
'depth':_int64_feature(image.shape[2]),
'label':_int64_feature(label)
}))
writer.write(example.SerializeToString())
writer.close()
defdisp_tfrecords(tfrecord_list_file):
filename_queue=tf.train.string_input_producer([tfrecord_list_file])
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)
features=tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'height':tf.FixedLenFeature([],tf.int64),
'width':tf.FixedLenFeature([],tf.int64),
'depth':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
}
)
image=tf.decode_raw(features['image_raw'],tf.uint8)
#print(repr(image))
height=features['height']
width=features['width']
depth=features['depth']
label=tf.cast(features['label'],tf.int32)
init_op=tf.initialize_all_variables()
resultImg=[]
resultLabel=[]
withtf.Session()assess:
sess.run(init_op)
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
foriinrange(21):
image_eval=image.eval()
resultLabel.append(label.eval())
image_eval_reshape=image_eval.reshape([height.eval(),width.eval(),depth.eval()])
resultImg.append(image_eval_reshape)
pilimg=Image.fromarray(np.asarray(image_eval_reshape))
pilimg.show()
coord.request_stop()
coord.join(threads)
sess.close()
returnresultImg,resultLabel
defread_tfrecord(filename_queuetemp):
filename_queue=tf.train.string_input_producer([filename_queuetemp])
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)
features=tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'width':tf.FixedLenFeature([],tf.int64),
'depth':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
}
)
image=tf.decode_raw(features['image_raw'],tf.uint8)
#image
tf.reshape(image,[256,256,3])
#normalize
image=tf.cast(image,tf.float32)*(1./255)-0.5
#label
label=tf.cast(features['label'],tf.int32)
returnimage,label
deftest():
transform2tfrecord(train_file,name,output_directory,resize_height,resize_width)#转化函数
img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords')#显示函数
img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords')#读取函数
printlabel
if__name__=='__main__':
test()
这样就可以得到自己专属的数据集.tfrecords了 ,它可以直接用于tensorflow的数据集。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。