tensorflow pb to tflite 精度下降详解
之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。
思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能满足业务需求了,最后就把OCR模型调用放在服务端了。不过精度下降的原因目前还没有找到,在这里记录一下。
工作思路:
1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;
但是在Python中使用TFLite Interpreter调用tflite文件时就已经出现了精度下降的问题,Android端部署也是一样。
1.网络结构
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

slim = tf.contrib.slim


def ttnet(images, num_classes=10, is_training=False,
          dropout_keep_prob=0.5,
          prediction_fn=slim.softmax,
          scope='TtNet'):
    """Build the TtNet classifier: four conv/max-pool stages plus a linear head.

    Args:
      images: input image batch. Presumably a 4-D NHWC float tensor (the
        export script feeds shape (None, 32, 32, 3)) -- confirm for other
        callers.
      num_classes: number of output classes (width of the final layer).
      is_training: True when building the training graph. Controls both the
        dropout layer and the two batch-normalization layers.
      dropout_keep_prob: keep probability used by dropout while training.
      prediction_fn: function mapping logits to the 'Predictions' end point.
      scope: optional variable scope (default 'TtNet').

    Returns:
      A ``(logits, end_points)`` tuple; ``end_points`` maps 'Flatten',
      'Logits' and 'Predictions' to the corresponding tensors.
    """
    end_points = {}
    # NOTE: slim.batch_norm defaults to is_training=True, i.e. it normalizes
    # with the statistics of the current batch. The flag therefore has to be
    # forwarded explicitly; otherwise a graph built with is_training=False
    # (such as the exported inference graph) still runs batch norm in
    # training mode. That makes predictions batch-dependent and is a
    # plausible cause of the pb-vs-tflite accuracy gap.
    # Running with is_training=False relies on the training loop having
    # executed tf.GraphKeys.UPDATE_OPS so moving_mean / moving_variance were
    # actually learned -- confirm for the training script in use.
    with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
        net = slim.conv2d(images, 32, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
        net = slim.batch_norm(net, activation_fn=tf.nn.relu,
                              is_training=is_training, scope='bn1')
        net = slim.conv2d(net, 64, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
        net = slim.conv2d(net, 128, [3, 3], scope='conv3')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
        net = slim.conv2d(net, 256, [3, 3], scope='conv4')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
        net = slim.batch_norm(net, activation_fn=tf.nn.relu,
                              is_training=is_training, scope='bn2')
        net = slim.flatten(net)
        end_points['Flatten'] = net
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout3')
        # Linear output layer: overrides the ReLU set by ttnet_arg_scope.
        logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                      scope='fc4')

    end_points['Logits'] = logits
    end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
    return logits, end_points


# NOTE(review): the export script feeds 32x32 inputs; confirm which size the
# model was trained with.
ttnet.default_image_size = 28


def ttnet_arg_scope(weight_decay=0.0):
    """Return the slim arg_scope used to build TtNet.

    conv2d and fully_connected layers get L2 weight regularization with
    ``weight_decay``, truncated-normal (stddev=0.1) weight initialization and
    ReLU activation.
    """
    with slim.arg_scope(
            [slim.conv2d, slim.fully_connected],
            weights_regularizer=slim.l2_regularizer(weight_decay),
            weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
            activation_fn=tf.nn.relu) as sc:
        return sc
基于slim,由于是一个比较简单的分类问题,网络结构也很简单,几个卷积加池化。
测试效果很好,在真实样本测试集上能达到99%以上的准确率。
2.模型固化,生成pb文件
# coding: utf-8
"""Export the network's inference graph (structure only) to a GraphDef .pb.

No checkpoint is restored here (``--checkpoint_path`` is not read), so the
written GraphDef contains the network ops but no trained weights. It still
has to be frozen against a checkpoint before it can be converted to TFLite.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

from datasets import dataset_factory
from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

# TODO: support an arbitrary num_classes (the image size is configurable via
# the eval_image_* flags below).
NUM_CLASSES = 37
# Input height/width used when no eval_image_* flag is given.
DEFAULT_IMAGE_SIZE = 32

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')
tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')
tf.app.flags.DEFINE_integer(
    'eval_image_height', None, 'Eval image height')
tf.app.flags.DEFINE_integer(
    'eval_image_width', None, 'Eval image width')
tf.app.flags.DEFINE_string(
    'export_path', './ttnet_1.0_37_32.pb', 'the export path of the pb file')

FLAGS = tf.app.flags.FLAGS


def main(_):
    """Build the inference graph and write its GraphDef to FLAGS.export_path."""
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=NUM_CLASSES,
        is_training=False)

    # Explicit height/width take precedence over eval_image_size; with no
    # flags the original 32x32 input is used.
    height = (FLAGS.eval_image_height or FLAGS.eval_image_size
              or DEFAULT_IMAGE_SIZE)
    width = (FLAGS.eval_image_width or FLAGS.eval_image_size
             or DEFAULT_IMAGE_SIZE)

    # The placeholder feeds the network directly: no preprocessing_factory
    # step is part of the exported graph, so callers must apply any
    # preprocessing themselves (presumably the same as in training -- confirm).
    images = tf.placeholder(tf.float32, (None, height, width, 3),
                            name='input_data')
    _, endpoints = network_fn(images)
    # Give the prediction tensor a stable name for freezing / TFLite conversion.
    tf.identity(endpoints['Predictions'], name='output_data')

    with gfile.GFile(FLAGS.export_path, 'wb') as f:
        f.write(tf.get_default_graph().as_graph_def().SerializeToString())


if __name__ == '__main__':
    tf.app.run()
3.生成tflite文件
import tensorflow as tf

# Frozen GraphDef (weights baked in as constants) to convert.
graph_def_file = "/datastore1/Colonist_Lord/Colonist_Lord/workspace/models/model_files/passport_model_with_tflite/ocr_frozen.pb"
# Node names given by the export script's placeholder and tf.identity.
input_arrays = ["input_data"]
output_arrays = ["output_data"]

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()

# Write the serialized model and close the file deterministically.
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
使用pb文件进行测试,效果正常;使用tflite文件进行测试,精度下降严重。下面附上pb与tflite测试代码。
pb测试代码
# Run the frozen .pb model on each image and copy the (resized) image into
# result_folder/<predicted label>/.
# NOTE(review): graph_filename, image_files, image_folder, result_folder,
# dict_laebl, count and the tf/cv2/np/os imports come from code not shown here.
with tf.gfile.GFile(graph_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    # import_graph_def prefixes node names with 'import/' by default.
    tf.import_graph_def(graph_def)
    input_node = graph.get_tensor_by_name('import/input_data:0')
    output_node = graph.get_tensor_by_name('import/output_data:0')
    # The input placeholder is NHWC: dim 1 is height, dim 2 is width.
    input_height = int(input_node.shape[1])
    input_width = int(input_node.shape[2])

    with tf.Session() as sess:
        for image_file in image_files:
            abs_path = os.path.join(image_folder, image_file)
            img = cv2.imread(abs_path)
            if img is None:
                # cv2.imread returns None (it does not raise) for unreadable files.
                print('skip unreadable image: %s' % abs_path)
                continue
            img = img.astype(np.float32)
            # cv2.resize takes dsize as (width, height), not (height, width).
            img = cv2.resize(img, (input_width, input_height))
            output_data = sess.run(
                output_node, feed_dict={input_node: np.expand_dims(img, 0)})
            index = np.argmax(output_data)
            label = dict_laebl[index]
            dst_folder = os.path.join(result_folder, label)
            if not os.path.isdir(dst_folder):
                os.makedirs(dst_folder)
            cv2.imwrite(os.path.join(dst_folder, image_file), img)
            count += 1
tflite测试代码
# Run the converted TFLite model on each image and copy the (resized) image
# into result_folder/<predicted label>/.
# NOTE(review): image_files, image_folder, result_folder, dict_laebl, count and
# the tf/cv2/np/os imports come from code not shown here.
model_path = "converted_model.tflite"  # "/datastore1/Colonist_Lord/Colonist_Lord/data/passport_char/ocr.tflite"
# tf.lite.Interpreter supersedes the deprecated tf.contrib.lite alias and
# matches the tf.lite.TFLiteConverter used for the conversion.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_index = input_details[0]['index']
input_dtype = input_details[0]['dtype']
output_index = output_details[0]['index']
# The input tensor is NHWC: [batch, height, width, channels].
input_height = int(input_details[0]['shape'][1])
input_width = int(input_details[0]['shape'][2])

for image_file in image_files:
    abs_path = os.path.join(image_folder, image_file)
    img = cv2.imread(abs_path)
    if img is None:
        # cv2.imread returns None (it does not raise) for unreadable files.
        print('skip unreadable image: %s' % abs_path)
        continue
    img = img.astype(np.float32)
    # cv2.resize takes dsize as (width, height), not (height, width).
    img = cv2.resize(img, (input_width, input_height))
    # set_tensor requires an ndarray whose dtype matches the input tensor.
    interpreter.set_tensor(input_index,
                           np.expand_dims(img, 0).astype(input_dtype))
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_index)
    index = np.argmax(output_data)
    label = dict_laebl[index]
    dst_folder = os.path.join(result_folder, label)
    if not os.path.isdir(dst_folder):
        os.makedirs(dst_folder)
    cv2.imwrite(os.path.join(dst_folder, image_file), img)
    count += 1
最后也算是绕过这个问题解决了业务需求,后面有空的话,还是会花时间研究一下这个问题。
如果有哪个大佬知道原因,希望不吝赐教。
补充知识:.pb转tflite的代码,使用训练后量化来减小模型体积,关键是设置converter.post_training_quantize=True。
import tensorflow as tf

path = "/home/python/Downloads/a.pb"  # location and file name of the .pb file
inputs = ["input_images"]  # input node name(s) of the model
# Output node name(s) of the model.
outputs = ['feature_fusion/Conv_7/Sigmoid', 'feature_fusion/concat_3']

# Older TF releases used the equivalent
# tf.contrib.lite.TocoConverter.from_frozen_graph(...).
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    path, inputs, outputs,
    input_shapes={'input_images': [1, 320, 320, 3]})
# Post-training weight quantization to shrink the model file. It can itself
# cost accuracy, so compare against the float model. Newer TF versions
# deprecate this flag in favour of converter.optimizations.
converter.post_training_quantize = True
tflite_model = converter.convert()

# Write the serialized model and close the file deterministically.
with open("/home/python/Downloads/aNew.tflite", "wb") as f:
    f.write(tflite_model)
以上这篇tensorflow pb to tflite精度下降详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。