tensorflow 固定部分参数训练,只训练部分参数的实例
在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练。
1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如
definference(input_): """ thisiswhereyouputyourgraph. thefollowingisjustanexample. """ conv1=tf.layers.conv2d(input_) conv2=tf.layers.conv2d(conv1) returnconv2 input_=tf.placeholder() output=inference(input_) ... calculate_loss_op=... train_op=... ... withtf.Session()assess: sess.run([loss,train_op],feed_dict={input_:train_data}) ifvalidation==True: sess.run([loss],feed_dict={input_:validate_date})
这种方式很简单,也很直接了然。
2.但是,如果处理的数据量很大的时候,使用tf.placeholder()来载入数据会严重地拖慢训练的进度,因此,常用tfrecords文件来读取数据。
此时,很容易想到,将不同的值传入inference()函数中进行计算。
train_batch,label_batch=decode_train() val_train_batch,val_label_batch=decode_validation() train_result=inference(train_batch) ... loss=.. train_op=... ... ifvalidation==True: val_result=inference(val_train_batch) val_loss=.. withtf.Session()assess: sess.run([loss,train_op]) ifvalidation==True: sess.run([val_result,val_loss])
这种方式看似能够直接调用inference()来对验证数据进行前向传播计算,但是,实则会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。
3.用一个tf.placeholder来控制是否训练、验证。
definference(input_): ... ... ... returninference_result train_batch,label_batch=decode_train() val_batch,val_label=decode_validation() is_training=tf.placeholder(tf.bool,shape=()) x=tf.cond(is_training,lambda:train_batch,lambda:val_batch) y=tf.cond(is_training,lambda:train_label,lambda:val_label) logits=inference(x) loss=cal_loss(logits,y) train_op=optimize(loss) withtf.Session()assess: loss,_=sess.run([loss,train_op],feed_dict={is_training:True}) ifvalidation==True: loss=sess.run(loss,feed_dict={is_training:False})
使用这种方式就可以在一个大图里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。
以上这篇tensorflow固定部分参数训练,只训练部分参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。