python生成tensorflow输入输出的图像格式的方法
TensorFlow能够识别的图像文件,可以通过numpy,使用tf.Variable或者tf.placeholder加载进tensorflow;也可以通过自带函数(tf.read_file)读取,当图像文件过多时,一般使用pipeline通过队列的方法进行读取。下面我们介绍两种生成tensorflow的图像格式的方法,供给tensorflow的graph的输入与输出。
# Load the Make3D dataset images from a MATLAB v7.3 .mat file (HDF5 format)
# into a numpy array shaped (num_images, height, width, 3).
import cv2
import numpy as np
import h5py

height = 460
width = 345

# h5py opens MATLAB v7.3 .mat files as HDF5; `[:]` reads the whole
# 'images' dataset into memory as a numpy array.
with h5py.File('make3d_dataset_f460.mat', 'r') as f:
    images = f['images'][:]

image_num = len(images)
# MATLAB stores arrays in column-major order, so the loaded layout is
# (num, channels, width, height); transpose to NHWC (num, height, width, 3).
# The original pre-allocated a uint8 zeros buffer here and immediately
# rebound `data`, discarding both the buffer and its dtype — the dead
# allocation is removed.  NOTE(review): `data` keeps whatever dtype the
# .mat dataset has; cast explicitly if uint8 is required downstream.
data = images.transpose((0, 3, 2, 1))
先生成图像文件的路径列表: ls *.jpg > list.txt
# Read every image listed in list.txt (one file name per line, produced by
# `ls *.jpg > list.txt`) into a single uint8 array of shape (N, 48, 48, 3).
import cv2
import numpy as np

image_path = './'
list_file = 'list.txt'
height = 48
width = 48

# Read the image file names, stripping the trailing newline from each line.
image_name_list = []
with open(image_path + list_file) as fid:
    image_name_list = [x.strip() for x in fid.readlines()]

image_num = len(image_name_list)
data = np.zeros((image_num, height, width, 3), np.uint8)
for idx in range(image_num):
    img = cv2.imread(image_name_list[idx])
    # cv2.resize expects (width, height); the original passed
    # (height, width), which only worked because both are 48 here.
    img = cv2.resize(img, (width, height))
    data[idx, :, :, :] = img
2 TensorFlow自带函数读取
def get_image(image_path):
    """Read the JPEG image at *image_path*.

    Args:
        image_path: tf.string tensor holding the path of a JPEG file.

    Returns:
        The decoded image as a tf.float32 tensor with 3 channels and
        values scaled into [0, 1].
    """
    # The original passed dtype=tf.uint8, contradicting its own docstring
    # ("casted to float32") and making convert_image_dtype a no-op, since
    # decode_jpeg already yields uint8.  Use tf.float32 as documented.
    return tf.image.convert_image_dtype(
        tf.image.decode_jpeg(
            tf.read_file(image_path), channels=3),
        dtype=tf.float32)
pipeline读取方法
# Example on how to use the tensorflow input pipelines. The explanation can be found here: ischlag.github.io.
import tensorflow as tf
import random
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes

# Dataset location and the CSV files mapping image file names to labels.
dataset_path = "/path/to/your/dataset/mnist/"
test_labels_file = "test-labels.csv"
train_labels_file = "train-labels.csv"

# Pipeline parameters: size of the hand-made test split, image geometry,
# and how many examples each batch holds.
test_set_size = 5
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
NUM_CHANNELS = 3
BATCH_SIZE = 5
def encode_label(label):
    """Parse a label field read from the CSV file into an integer."""
    encoded = int(label)
    return encoded
def read_label_file(file):
    """Read a CSV file of "filepath,label" lines.

    Args:
        file: path of the CSV file to read.

    Returns:
        (filepaths, labels): parallel lists of file-path strings and
        integer labels, in file order.
    """
    filepaths = []
    labels = []
    # The original opened the file and never closed it; `with` guarantees
    # the handle is released even if a line fails to parse.
    with open(file, "r") as f:
        for line in f:
            filepath, label = line.split(",")
            filepaths.append(filepath)
            # int() inside encode_label tolerates the trailing newline
            # left on the label field.
            labels.append(encode_label(label))
    return filepaths, labels
# Read the labels and file paths from the CSV listings.
train_filepaths, train_labels = read_label_file(dataset_path + train_labels_file)
test_filepaths, test_labels = read_label_file(dataset_path + test_labels_file)

# Turn the relative paths into full paths.
train_filepaths = [dataset_path + fp for fp in train_filepaths]
test_filepaths = [dataset_path + fp for fp in test_filepaths]

# For this example we build our own test partition out of the combined
# data, keeping only the first 20 examples.
all_filepaths = (train_filepaths + test_filepaths)[:20]
all_labels = (train_labels + test_labels)[:20]

# Convert the python lists into string / int32 tensors.
all_images = ops.convert_to_tensor(all_filepaths, dtype=dtypes.string)
all_labels = ops.convert_to_tensor(all_labels, dtype=dtypes.int32)

# Partition vector: 1 marks a test example, 0 a training example; shuffle
# so the test examples are drawn at random positions.
partitions = [0] * len(all_filepaths)
partitions[:test_set_size] = [1] * test_set_size
random.shuffle(partitions)

# Split images and labels into a train and a test set according to the
# partition vector.
train_images, test_images = tf.dynamic_partition(all_images, partitions, 2)
train_labels, test_labels = tf.dynamic_partition(all_labels, partitions, 2)
# Create input queues that yield one (path, label) pair at a time.
train_input_queue = tf.train.slice_input_producer(
    [train_images, train_labels],
    shuffle=False)
test_input_queue = tf.train.slice_input_producer(
    [test_images, test_labels],
    shuffle=False)

# Turn each queued path string into a decoded image tensor; the label
# passes through unchanged.
train_image = tf.image.decode_jpeg(
    tf.read_file(train_input_queue[0]), channels=NUM_CHANNELS)
train_label = train_input_queue[1]
test_image = tf.image.decode_jpeg(
    tf.read_file(test_input_queue[0]), channels=NUM_CHANNELS)
test_label = test_input_queue[1]

# decode_jpeg leaves the static shape unknown; pin it so batching works.
train_image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS])
test_image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS])

# Collect batches of images before processing (num_threads defaults to 1).
train_image_batch, train_label_batch = tf.train.batch(
    [train_image, train_label],
    batch_size=BATCH_SIZE
    # , num_threads=1
)
test_image_batch, test_label_batch = tf.train.batch(
    [test_image, test_label],
    batch_size=BATCH_SIZE
    # , num_threads=1
)
print("input pipeline ready")

with tf.Session() as sess:
    # Initialize the variables.
    sess.run(tf.initialize_all_variables())

    # Start the queue-runner threads that shovel data into the queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    print("from the train set:")
    for i in range(20):
        print(sess.run(train_label_batch))

    print("from the test set:")
    for i in range(10):
        print(sess.run(test_label_batch))

    # Stop the queue threads and properly close the session.
    coord.request_stop()
    coord.join(threads)
    sess.close()
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。