TensorFLow用Saver保存和恢复变量
本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下
建立文件tensor_save.py,保存变量v1,v2的tensor到checkpointfiles中,名称分别设置为v3,v4。
importtensorflowastf #Createsomevariables. v1=tf.Variable(3,name="v1") v2=tf.Variable(4,name="v2") #Createmodel y=tf.add(v1,v2) #Addanoptoinitializethevariables. init_op=tf.initialize_all_variables() #Addopstosaveandrestoreallthevariables. saver=tf.train.Saver({'v3':v1,'v4':v2}) #Later,launchthemodel,initializethevariables,dosomework,savethe #variablestodisk. withtf.Session()assess: sess.run(init_op) print("v1=",v1.eval()) print("v2=",v2.eval()) #Savethevariablestodisk. save_path=saver.save(sess,"f:/tmp/model.ckpt") print("Modelsavedinfile:",save_path)
建立文件tensor_restror.py,将checkpointfiles中名称分别为v3,v4的tensor分别恢复到变量v3,v4中。
importtensorflowastf #Createsomevariables. v3=tf.Variable(0,name="v3") v4=tf.Variable(0,name="v4") #Createmodel y=tf.mul(v3,v4) #Addopstosaveandrestoreallthevariables. saver=tf.train.Saver() #Later,launchthemodel,usethesavertorestorevariablesfromdisk,and #dosomeworkwiththemodel. withtf.Session()assess: #Restorevariablesfromdisk. saver.restore(sess,"f:/tmp/model.ckpt") print("Modelrestored.") print("v3=",v3.eval()) print("v4=",v4.eval()) print("y=",sess.run(y))
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。