Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解
一、保存:
graph_util.convert_variables_to_constants可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量
参数2:GraphDef对象,它描述了计算网络
参数3:Graph图中需要输出的节点的名称的列表
返回值:精简版的GraphDef对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:
constant_graph=graph_util.convert_variables_to_constants(sess,sess.graph_def,['sum_operation']) withopen(pbName,mode='wb')asf: f.write(constant_graph.SerializeToString())
需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。
二、恢复:
恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:
graph0=tf.GraphDef() withopen(pbName,mode='rb')asf: graph0.ParseFromString(f.read()) tf.import_graph_def(graph0,name='')
三、代码:
importtensorflowastf
fromtensorflow.python.frameworkimportgraph_util
pbName='graphA.pb'
defgraphCreate():
withtf.Session()assess:
var1=tf.placeholder(tf.int32,name='var1')
var2=tf.Variable(20,name='var2')#实参name='var2'指定了操作名,该操作返回的张量名是在
#'var2'后面:0,即var2:0是返回的张量名,也就是说变量
#var2的名称是'var2:0'
var3=tf.Variable(30,name='var3')
var4=tf.Variable(40,name='var4')
var4op=tf.assign(var4,1000,name='var4op1')
sum=tf.Variable(4,name='sum')
sum=tf.add(var1,var2,name='var1_var2')
sum=tf.add(sum,var3,name='sum_var3')
sumOps=tf.add(sum,var4,name='sum_operation')
oper=tf.get_default_graph().get_operations()
withopen('operation.csv','wt')asf:
s='name,type,output\n'
f.write(s)
foroinoper:
s=o.name
s+=','+o.type
inp=o.inputs
oup=o.outputs
foriipininp:
s#s+=','+str(iip)
foriopinoup:
s+=','+str(iop)
s+='\n'
f.write(s)
forvarintf.global_variables():
print('variable=>',var.name)#张量是tf.Variable/tf.Add之类操作的结果,
#张量的名字使用操作名加:0来表示
init=tf.global_variables_initializer()
sess.run(init)
sess.run(var4op)
print('sum_operationresultisTensor',sess.run(sumOps,feed_dict={var1:1}))
constant_graph=graph_util.convert_variables_to_constants(sess,sess.graph_def,['sum_operation'])
withopen(pbName,mode='wb')asf:
f.write(constant_graph.SerializeToString())
defgraphGet():
print("startget:")
withtf.Graph().as_default():
graph0=tf.GraphDef()
withopen(pbName,mode='rb')asf:
graph0.ParseFromString(f.read())
tf.import_graph_def(graph0,name='')
withtf.Session()assess:
init=tf.global_variables_initializer()
sess.run(init)
v1=sess.graph.get_tensor_by_name('var1:0')
v2=sess.graph.get_tensor_by_name('var2:0')
v3=sess.graph.get_tensor_by_name('var3:0')
v4=sess.graph.get_tensor_by_name('var4:0')
sumTensor=sess.graph.get_tensor_by_name("sum_operation:0")
print('sumTensoris:',sumTensor)
print(sess.run(sumTensor,feed_dict={v1:1}))
graphCreate()
graphGet()
四、保存pb函数代码里的操作名称/类型/返回的张量:
以上这篇Tensorflow使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。