TensorFlow Estimator: implementing fine-tuning with hooks
There are two ways to implement fine-tuning:

Option 1: assign the pretrained values directly after the model is defined inside model_fn:
def model_fn(features, labels, mode, params):
    # ... build the model ...

    # Fine-tune: only warm-start from the pretrained checkpoint when
    # model_dir does not already contain a checkpoint (i.e. on the first run).
    if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
        checkpoint_path = None
        if tf.gfile.IsDirectory(params.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            checkpoint_path = params.checkpoint_path

        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            assignment_map={params.checkpoint_scope: params.checkpoint_scope}  # e.g. 'OptimizeLoss/': 'OptimizeLoss/'
        )
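Note that assignment_map accepts a few other mapping forms as well. A minimal sketch of the common variants (the scope names here are placeholders, not from the original post):

    # Restore every checkpoint variable into the variable with the same name:
    tf.train.init_from_checkpoint(checkpoint_path, assignment_map={'/': '/'})

    # Restore everything under one scope, optionally renaming the scope:
    tf.train.init_from_checkpoint(checkpoint_path,
                                  assignment_map={'old_scope/': 'new_scope/'})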
Option 2: use hooks.
Hooks can be attached through the train_monitors argument when the tf.contrib.learn.Experiment is defined:
# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,                          # Estimator
    train_input_fn=train_input_fn,                # First-class function
    eval_input_fn=eval_input_fn,                  # First-class function
    train_steps=params.train_steps,               # Minibatch steps
    min_eval_frequency=params.eval_min_frequency, # Eval frequency
    # train_monitors=[],                          # Hooks for training
    # eval_hooks=[eval_input_hook],               # Hooks for evaluation
    eval_steps=params.eval_steps                  # Use evaluation feeder until it is empty
)
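For example, the RestoreCheckpointHook defined later in this post could be attached here. A sketch (the pattern strings are taken from the comments in the hook code below):

    restore_hook = RestoreCheckpointHook(
        checkpoint_path=params.checkpoint_path,
        exclude_scope_patterns='biases,Logits',
        include_scope_patterns='Conv'
    )

    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=params.train_steps,
        min_eval_frequency=params.eval_min_frequency,
        train_monitors=[restore_hook],  # run the restore hook during training
        eval_steps=params.eval_steps
    )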
They can also be attached through the training_chief_hooks argument when the tf.estimator.EstimatorSpec is defined.
Personally, though, I think it is best to define them in the estimator, so that the experiment only has to control the experiment schedule (number of training steps, evaluation steps, and so on).
def model_fn(features, labels, mode, params):
    # ....
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),
        # training_chief_hooks=None
    )
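The commented-out scaffold argument hints at a third route: a tf.train.Scaffold can carry an init_fn that restores pretrained weights once the session is created. A minimal sketch of such a get_scaffold (the checkpoint_path argument is an assumption, not from the original post):

    def get_scaffold(checkpoint_path):
        # Build a saver over the variables we want to warm-start.
        variables_to_restore = tf.contrib.framework.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def init_fn(scaffold, session):
            # Called once, after the session is created but before training starts.
            saver.restore(session, checkpoint_path)

        return tf.train.Scaffold(init_fn=init_fn)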
While we are at it, let me explain what the tf.estimator.EstimatorSpec object does. It describes every aspect of a model, including:

The current mode:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

The computation graph:

predictions: Predictions Tensor or dict of Tensor.
loss: Training loss Tensor. Must be either scalar, or with shape [1].
train_op: Op for the training step.
eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

The export strategy:

export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:
name: An arbitrary name for this output.
output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

Chief hooks, e.g. the model-saving hook CheckpointSaverHook used during training, model restoration, etc.:

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

Worker hooks, i.e. monitoring hooks used during training, such as NanTensorHook and LoggingTensorHook:

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

Initialization and the saver:

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

Evaluation hooks:

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
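Putting a few of these fields together, here is a minimal self-contained sketch of a model_fn (the feature key 'x', the single dense layer, and the metric name are illustrative assumptions, not from the original post):

    import tensorflow as tf

    def model_fn(features, labels, mode, params):
        # A toy model: one dense layer producing class logits.
        logits = tf.layers.dense(features['x'], params.num_classes)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits)
        }

        # Prediction: only the forward graph and the export signature are needed.
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs={
                    'serving_default': tf.estimator.export.PredictOutput(predictions)
                })

        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_global_step())
        # (metric_tensor, update_op) tuples, as described above.
        eval_metric_ops = {
            'accuracy': tf.metrics.accuracy(labels, predictions['classes'])
        }

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops)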
The custom hook looks like this:
import os

import tensorflow as tf

slim = tf.contrib.slim


class RestoreCheckpointHook(tf.train.SessionRunHook):
    def __init__(self,
                 checkpoint_path,
                 exclude_scope_patterns,
                 include_scope_patterns):
        tf.logging.info("Create RestoreCheckpointHook.")
        # super(IteratorInitializerHook, self).__init__()
        self.checkpoint_path = checkpoint_path
        # Comma-separated pattern strings are split into lists of regexps.
        self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
        self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')

    def begin(self):
        # You can add ops to the graph here.
        print('Before starting the session.')

        # 1. Create saver

        # exclusions = []
        # if self.checkpoint_exclude_scopes:
        #     exclusions = [scope.strip()
        #                   for scope in self.checkpoint_exclude_scopes.split(',')]
        #
        # variables_to_restore = []
        # for var in slim.get_model_variables():  # tf.global_variables():
        #     excluded = False
        #     for exclusion in exclusions:
        #         if var.op.name.startswith(exclusion):
        #             excluded = True
        #             break
        #     if not excluded:
        #         variables_to_restore.append(var)
        # inclusions
        # [var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

        variables_to_restore = tf.contrib.framework.filter_variables(
            slim.get_model_variables(),
            include_patterns=self.include_scope_patterns,  # ['Conv'],
            exclude_patterns=self.exclude_scope_patterns,  # ['biases', 'Logits'],
            # If True (default), performs re.search to find matches
            # (i.e. pattern can match any substring of the variable name).
            # If False, performs re.match
            # (i.e. regexp should match from the beginning of the variable name).
            reg_search=True
        )
        self.saver = tf.train.Saver(variables_to_restore)

    def after_create_session(self, session, coord):
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        print('Session created.')

        tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
        self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
        tf.logging.info('End fine-tune from %s' % self.checkpoint_path)

    def before_run(self, run_context):
        # print('Before calling session.run().')
        return None  # SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        # print('Done running one step. The value of my tensor: %s', run_values.results)
        # if you-need-to-stop-loop:
        #     run_context.request_stop()
        pass

    def end(self, session):
        # print('Done with the session.')
        pass
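Following the recommendation above to attach the hook in the estimator rather than the experiment, it can be passed through training_chief_hooks. A sketch (the params fields are assumptions):

    def model_fn(features, labels, mode, params):
        # ... build predictions, loss, train_op, eval_metric_ops as before ...

        restore_hook = RestoreCheckpointHook(
            checkpoint_path=params.checkpoint_path,
            exclude_scope_patterns=params.exclude_scope_patterns,  # e.g. 'biases,Logits'
            include_scope_patterns=params.include_scope_patterns   # e.g. 'Conv'
        )

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            # Chief-only, so the checkpoint is restored exactly once
            # in a distributed setting.
            training_chief_hooks=[restore_hook])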