Fine-tuning with hooks in TensorFlow Estimator
There are two ways to implement fine-tuning:
1. Assign the pretrained weights directly in model_fn, right after the model is defined:
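```python
import tensorflow as tf  # TensorFlow 1.x


def model_fn(features, labels, mode, params):
    # ... build the model ...

    # Fine-tune: initialize from params.checkpoint_path, but only when model_dir
    # holds no checkpoint yet, i.e. on the first run.
    if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
        checkpoint_path = None
        if tf.gfile.IsDirectory(params.checkpoint_path):
            # A directory was given: use the newest checkpoint inside it.
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            checkpoint_path = params.checkpoint_path

        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            # e.g. {'OptimizeLoss/': 'OptimizeLoss/'}: load every variable under this scope
            assignment_map={params.checkpoint_scope: params.checkpoint_scope}
        )
```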
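The assignment_map={scope: scope} argument is the least obvious part of this snippet. For scope entries ('ckpt_scope/': 'graph_scope/'), init_from_checkpoint takes every variable under the graph-side scope (the value) and initializes it from the checkpoint tensor with the same suffix under the checkpoint-side scope (the key). The pure-Python sketch below models only that scope-to-scope case on plain name strings so it runs without TensorFlow; resolve_scope_assignment and all variable names are hypothetical, and the root '/' form and the variable-object forms of assignment_map are deliberately left out.

```python
def resolve_scope_assignment(assignment_map, graph_var_names, ckpt_var_names):
    """Simplified model of resolving scope entries of an assignment_map.

    assignment_map: {checkpoint_scope/: graph_scope/}
    Returns {graph_variable_name: checkpoint_tensor_name}.
    Only 'scope/': 'scope/' entries are handled in this sketch.
    """
    ckpt = set(ckpt_var_names)
    plan = {}
    for ckpt_scope, graph_scope in assignment_map.items():
        if not (ckpt_scope.endswith('/') and graph_scope.endswith('/')):
            raise ValueError('this sketch only handles scope entries ending in "/"')
        for name in graph_var_names:
            if name.startswith(graph_scope):
                # Keep the suffix, swap the graph scope for the checkpoint scope.
                source = ckpt_scope + name[len(graph_scope):]
                if source not in ckpt:
                    # The sketch treats a missing tensor as an error.
                    raise ValueError('%s not found in checkpoint' % source)
                plan[name] = source
    return plan


graph_vars = ['OptimizeLoss/dense/kernel', 'OptimizeLoss/dense/bias', 'head/logits/kernel']
ckpt_vars = ['OptimizeLoss/dense/kernel', 'OptimizeLoss/dense/bias', 'global_step']
print(resolve_scope_assignment({'OptimizeLoss/': 'OptimizeLoss/'}, graph_vars, ckpt_vars))
```

With {'OptimizeLoss/': 'OptimizeLoss/'} as in the snippet, only the variables under OptimizeLoss/ are taken from the checkpoint; every other variable (head/logits/kernel here) keeps its normal initializer. Using different key and value scopes, e.g. {'old/': 'new/'}, is how weights can be loaded into a renamed scope.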
2. Use hooks.
Hooks can be passed through the train_monitors parameter when defining a tf.contrib.learn.Experiment:
```python
# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,                           # Estimator
    train_input_fn=train_input_fn,                 # First-class function
    eval_input_fn=eval_input_fn,                   # First-class function
    train_steps=params.train_steps,                # Minibatch steps
    min_eval_frequency=params.eval_min_frequency,  # Eval frequency
    # train_monitors=[],                           # Hooks for training
    # eval_hooks=[eval_input_hook],                # Hooks for evaluation
    eval_steps=params.eval_steps                   # Use evaluation feeder until it's empty
)
```
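They can also be passed through the training_chief_hooks parameter when constructing a tf.estimator.EstimatorSpec.
Personally, I think it is better to define them in the estimator, so that the Experiment only has to control how the experiment runs (number of training steps, number of evaluation steps, and so on).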
```python
def model_fn(features, labels, mode, params):
    # ... build predictions, loss, train_op, eval_metric_ops ...
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),
        # training_chief_hooks=None
    )
```
While we are at it, here is what a tf.estimator.EstimatorSpec object is for. It describes every aspect of a model, including:
Current mode:
mode: A ModeKeys. Specifies if this is training, evaluation or prediction.
Computation graph:
predictions: Predictions Tensor or dict of Tensor.
loss: Training loss Tensor. Must be either scalar, or with shape [1].
train_op: Op for the training step.
eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically it is a pure computation based on variables). For example, it should not trigger the update_op or require any input fetching.
Export strategy:
export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:
  name: An arbitrary name for this output.
  output: An ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
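Chief hooks: hooks that run only on the chief during training, e.g. the checkpoint-saving hook CheckpointSaverHook, model restoring, etc.
training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.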
Worker hooks: monitoring hooks that run during training, e.g. NanTensorHook, LoggingTensorHook.
training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.
Initialization and saver:
scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.
Evaluation hooks:
evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
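As a rough mental model of when these hooks fire, the sketch below drives hook objects in the order a TF 1.x MonitoredTrainingSession uses, as I understand it: begin() on every hook while the graph can still change, after_create_session() once the session exists (after any checkpoint in model_dir has been restored), then before_run()/after_run() around every step until a hook calls request_stop(), and end() on a normal exit. On the chief, training_chief_hooks run in addition to training_hooks; other workers only get training_hooks. RunContext, RecordingHook and run_training_loop are hypothetical plain-Python stand-ins, not TensorFlow's implementation; in particular the real loop passes SessionRunArgs/SessionRunValues instead of None.

```python
class RunContext:
    """Stand-in for tf.train.SessionRunContext: lets a hook ask the loop to stop."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class RecordingHook:
    """Hypothetical hook that records every callback it receives."""

    def __init__(self, name, log, stop_at=None):
        self.name, self.log, self.stop_at = name, log, stop_at
        self.step = 0

    def begin(self):
        self.log.append((self.name, 'begin'))

    def after_create_session(self, session, coord):
        self.log.append((self.name, 'after_create_session'))

    def before_run(self, run_context):
        self.log.append((self.name, 'before_run'))
        return None

    def after_run(self, run_context, run_values):
        self.step += 1
        self.log.append((self.name, 'after_run'))
        if self.stop_at is not None and self.step >= self.stop_at:
            run_context.request_stop()

    def end(self, session):
        self.log.append((self.name, 'end'))


def run_training_loop(training_hooks, training_chief_hooks, is_chief, max_steps):
    """Simplified model of the order in which a training session drives hooks."""
    hooks = (list(training_chief_hooks) if is_chief else []) + list(training_hooks)
    for h in hooks:
        h.begin()                      # graph can still be modified here
    session = object()                 # placeholder session; graph is finalized from here on
    for h in hooks:
        h.after_create_session(session, None)  # a restore hook would load weights here
    for _ in range(max_steps):
        ctx = RunContext()
        for h in hooks:
            h.before_run(ctx)
        run_values = None              # the real loop calls session.run(train_op, ...) here
        for h in hooks:
            h.after_run(ctx, run_values)
        if ctx.stop_requested:
            break
    for h in hooks:
        h.end(session)                 # only reached on a normal exit
    return hooks
```

The ordering is the reason a restore hook does its work in after_create_session(): that is the first point where a live session exists, and it comes after the normal variable initialization, so restored values are not overwritten by initializers.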
The custom hook is as follows:
```python
import os

import tensorflow as tf  # TensorFlow 1.x (tf.contrib is required)

slim = tf.contrib.slim


class RestoreCheckpointHook(tf.train.SessionRunHook):
    """Restores a filtered subset of model variables right after the session is created."""

    def __init__(self,
                 checkpoint_path,
                 exclude_scope_patterns,
                 include_scope_patterns
                 ):
        tf.logging.info("Create RestoreCheckpointHook.")
        super(RestoreCheckpointHook, self).__init__()
        # Must be a checkpoint file prefix (saver.restore does not accept a directory).
        self.checkpoint_path = checkpoint_path
        # Comma-separated regex patterns -> list of patterns, or None for "no filter".
        self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
        self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')

    def begin(self):
        # You can add ops to the graph here.
        print('Before starting the session.')

        # 1. Create saver
        # Manual alternative, kept for reference:
        # exclusions = []
        # if self.checkpoint_exclude_scopes:
        #     exclusions = [scope.strip()
        #                   for scope in self.checkpoint_exclude_scopes.split(',')]
        #
        # variables_to_restore = []
        # for var in slim.get_model_variables():  # tf.global_variables():
        #     excluded = False
        #     for exclusion in exclusions:
        #         if var.op.name.startswith(exclusion):
        #             excluded = True
        #             break
        #     if not excluded:
        #         variables_to_restore.append(var)
        # inclusions
        # [var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

        # Note: slim.get_model_variables() only returns variables in the MODEL_VARIABLES collection.
        variables_to_restore = tf.contrib.framework.filter_variables(
            slim.get_model_variables(),
            include_patterns=self.include_scope_patterns,  # ['Conv'],
            exclude_patterns=self.exclude_scope_patterns,  # ['biases', 'Logits'],
            # If True (default), performs re.search to find matches
            # (i.e. pattern can match any substring of the variable name).
            # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
            reg_search=True
        )
        self.saver = tf.train.Saver(variables_to_restore)

    def after_create_session(self, session, coord):
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        # This runs on every new session, so when training resumes from model_dir it
        # overwrites the resumed weights again (the model_fn approach guards against that).
        print('Session created.')
        tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
        self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
        tf.logging.info('Finished fine-tuning from %s' % self.checkpoint_path)

    def before_run(self, run_context):
        # print('Before calling session.run().')
        return None  # SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        # print('Done running one step. The value of my tensor: %s', run_values.results)
        # if you-need-to-stop-loop:
        #     run_context.request_stop()
        pass

    def end(self, session):
        # print('Done with the session.')
        pass
```
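The include/exclude filtering that the hook delegates to tf.contrib.framework.filter_variables can be modelled on plain name strings: a variable is kept if it matches at least one include pattern (or no include patterns were given) and none of the exclude patterns, using re.search by default or re.match when reg_search=False. The sketch below shows how the comma-separated strings passed to RestoreCheckpointHook turn into a restore list; parse_patterns, filter_variable_names and the variable names are hypothetical, and the real function works on tf.Variable objects rather than strings.

```python
import re


def parse_patterns(csv):
    """Mirrors the hook's constructor: '' or None -> None, else split on commas."""
    return None if not csv else csv.split(',')


def filter_variable_names(names, include_patterns=None, exclude_patterns=None, reg_search=True):
    """Simplified string-based model of filter_variables.

    Keeps a name if it matches at least one include pattern (or include_patterns
    is None) and matches none of the exclude patterns (if any are given).
    """
    match = re.search if reg_search else re.match
    kept = []
    for name in names:
        if include_patterns is not None and not any(match(p, name) for p in include_patterns):
            continue
        if exclude_patterns is not None and any(match(p, name) for p in exclude_patterns):
            continue
        kept.append(name)
    return kept


names = [
    'InceptionResnetV1/Conv2d_1a_3x3/weights',
    'InceptionResnetV1/Conv2d_1a_3x3/biases',
    'InceptionResnetV1/Logits/weights',
    'global_step',
]
print(filter_variable_names(names,
                            include_patterns=parse_patterns('InceptionResnetV1'),
                            exclude_patterns=parse_patterns('biases,Logits')))
```

Because the constructor splits on ',' without stripping whitespace, a value such as 'biases, Logits' would yield the pattern ' Logits' with a leading space, which matches none of the names above; passing the patterns without spaces avoids that.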