python使用tensorflow保存、加载和使用模型的方法
使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
我对这篇文章进行了整理和汇总。
首先是模型的保存。直接上代码:
#!/usr/bin/envpython
#-*-coding:utf-8-*-
############################
#FileName:tut1_save.py
#Author:Wang
#Mail:wang19920419@hotmail.com
#CreatedTime:2017-08-3011:04:25
############################
importtensorflowastf
#preparetofeedinput,i.e.feed_dictandplaceholders
w1=tf.Variable(tf.random_normal(shape=[2]),name='w1')#nameisveryimportantinrestoration
w2=tf.Variable(tf.random_normal(shape=[2]),name='w2')
b1=tf.Variable(2.0,name='bias1')
feed_dict={w1:[10,3],w2:[5,5]}
#defineatestoperationthatwillberestored
w3=tf.add(w1,w2)#withoutname,w3willnotbestored
w4=tf.multiply(w3,b1,name="op_to_restore")
#saver=tf.train.Saver()
saver=tf.train.Saver(max_to_keep=4,keep_checkpoint_every_n_hours=1)
sess=tf.Session()
sess.run(tf.global_variables_initializer())
printsess.run(w4,feed_dict)
#saver.save(sess,'my_test_model',global_step=100)
saver.save(sess,'my_test_model')
#saver.save(sess,'my_test_model',global_step=100,write_meta_graph=False)
需要说明的有以下几点:
1.创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。
2.saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph=False加以限制。
3.这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。
下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess,'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:
#!/usr/bin/envpython
#-*-coding:utf-8-*-
############################
#FileName:tut2_import.py
#Author:Wang
#Mail:wang19920419@hotmail.com
#CreatedTime:2017-08-3014:16:38
############################
importtensorflowastf
sess=tf.Session()
new_saver=tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess,tf.train.latest_checkpoint('./'))
printsess.run('w1:0')
使用加载的模型,输入新数据,计算输出,还是直接上代码:
#!/usr/bin/envpython
#-*-coding:utf-8-*-
############################
#FileName:tut3_reuse.py
#Author:Wang
#Mail:wang19920419@hotmail.com
#CreatedTime:2017-08-3014:33:35
############################
importtensorflowastf
sess=tf.Session()
#First,loadmetagraphandrestoreweights
saver=tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
#Second,accessandcreateplaceholdersvariablesandcreatefeed_dicttofeednewdata
graph=tf.get_default_graph()
w1=graph.get_tensor_by_name('w1:0')
w2=graph.get_tensor_by_name('w2:0')
feed_dict={w1:[-1,1],w2:[4,6]}
#Accesstheopthatwanttorun
op_to_restore=graph.get_tensor_by_name('op_to_restore:0')
printsess.run(op_to_restore,feed_dict)#ouotput:[6.14.]
在已经加载的网络后继续加入新的网络层:
importtensorflowastf
sess=tf.Session()
#Firstlet'sloadmetagraphandrestoreweights
saver=tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
#Now,let'saccessandcreateplaceholdersvariablesand
#createfeed-dicttofeednewdata
graph=tf.get_default_graph()
w1=graph.get_tensor_by_name("w1:0")
w2=graph.get_tensor_by_name("w2:0")
feed_dict={w1:13.0,w2:17.0}
#Now,accesstheopthatyouwanttorun.
op_to_restore=graph.get_tensor_by_name("op_to_restore:0")
#Addmoretothecurrentgraph
add_on_op=tf.multiply(op_to_restore,2)
printsess.run(add_on_op,feed_dict)
#Thiswillprint120.
对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):
......
......
saver=tf.train.import_meta_graph('vgg.meta')
#Accessthegraph
graph=tf.get_default_graph()
##Preparethefeed_dictforfeedingdataforfine-tuning
#Accesstheappropriateoutputforfine-tuning
fc7=graph.get_tensor_by_name('fc7:0')
#usethisifyouonlywanttochangegradientsofthelastlayer
fc7=tf.stop_gradient(fc7)#It'sanidentityfunction
fc7_shape=fc7.get_shape().as_list()
new_outputs=2
weights=tf.Variable(tf.truncated_normal([fc7_shape[3],num_outputs],stddev=0.05))
biases=tf.Variable(tf.constant(0.05,shape=[num_outputs]))
output=tf.matmul(fc7,weights)+biases
pred=tf.nn.softmax(output)
#Now,yourunthiswithfine-tuningdatainsess.run()
有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。