Tensorflow训练MNIST手写数字识别模型
本文实例为大家分享了Tensorflow训练MNIST手写数字识别模型的具体代码,供大家参考,具体内容如下
importtensorflowastf fromtensorflow.examples.tutorials.mnistimportinput_data INPUT_NODE=784#输入层节点=图片像素=28x28=784 OUTPUT_NODE=10#输出层节点数=图片类别数目 LAYER1_NODE=500#隐藏层节点数,只有一个隐藏层 BATCH_SIZE=100#一个训练包中的数据个数,数字越小 #越接近随机梯度下降,越大越接近梯度下降 LEARNING_RATE_BASE=0.8#基础学习率 LEARNING_RATE_DECAY=0.99#学习率衰减率 REGULARIZATION_RATE=0.0001#正则化项系数 TRAINING_STEPS=30000#训练轮数 MOVING_AVG_DECAY=0.99#滑动平均衰减率 #定义一个辅助函数,给定神经网络的输入和所有参数,计算神经网络的前向传播结果 definference(input_tensor,avg_class,weights1,biases1, weights2,biases2): #当没有提供滑动平均类时,直接使用参数当前取值 ifavg_class==None: #计算隐藏层前向传播结果 layer1=tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1) #计算输出层前向传播结果 returntf.matmul(layer1,weights2)+biases2 else: #首先计算变量的滑动平均值,然后计算前向传播结果 layer1=tf.nn.relu( tf.matmul(input_tensor,avg_class.average(weights1))+ avg_class.average(biases1)) returntf.matmul( layer1,avg_class.average(weights2))+avg_class.average(biases2) #训练模型的过程 deftrain(mnist): x=tf.placeholder(tf.float32,[None,INPUT_NODE],name='x-input') y_=tf.placeholder(tf.float32,[None,OUTPUT_NODE],name='y-input') #生成隐藏层参数 weights1=tf.Variable( tf.truncated_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1)) biases1=tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE])) #生成输出层参数 weights2=tf.Variable( tf.truncated_normal([LAYER1_NODE,OUTPUT_NODE],stddev=0.1)) biases2=tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE])) #计算前向传播结果,不使用参数滑动平均值avg_class=None y=inference(x,None,weights1,biases1,weights2,biases2) #定义训练轮数变量,指定为不可训练 global_step=tf.Variable(0,trainable=False) #给定滑动平均衰减率和训练轮数的变量,初始化滑动平均类 variable_avgs=tf.train.ExponentialMovingAverage( MOVING_AVG_DECAY,global_step) #在所有代表神经网络参数的可训练变量上使用滑动平均 variables_avgs_op=variable_avgs.apply(tf.trainable_variables()) #计算使用滑动平均值后的前向传播结果 avg_y=inference(x,variable_avgs,weights1,biases1,weights2,biases2) #计算交叉熵作为损失函数 cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits( logits=y,labels=tf.argmax(y_,1)) cross_entropy_mean=tf.reduce_mean(cross_entropy) #计算L2正则化损失函数 regularizer=tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) regularization=regularizer(weights1)+regularizer(weights2) loss=cross_entropy_mean+regularization #设置指数衰减的学习率 learning_rate=tf.train.exponential_decay( LEARNING_RATE_BASE, global_step,#当前迭代轮数 mnist.train.num_examples/BATCH_SIZE,#过完所有训练数据的迭代次数 LEARNING_RATE_DECAY) #优化损失函数 train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize( loss,global_step=global_step) #反向传播同时更新神经网络参数及其滑动平均值 withtf.control_dependencies([train_step,variables_avgs_op]): train_op=tf.no_op(name='train') #检验使用了滑动平均模型的神经网络前向传播结果是否正确 correct_prediction=tf.equal(tf.argmax(avg_y,1),tf.argmax(y_,1)) accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #初始化会话并开始训练 withtf.Session()assess: tf.global_variables_initializer().run() #准备验证数据,用于判断停止条件和训练效果 validate_feed={x:mnist.validation.images, y_:mnist.validation.labels} #准备测试数据,用于模型优劣的最后评价标准 test_feed={x:mnist.test.images,y_:mnist.test.labels} #迭代训练神经网络 foriinrange(TRAINING_STEPS): ifi%1000==0: validate_acc=sess.run(accuracy,feed_dict=validate_feed) print("After%dtrainingstep(s),validationaccuracyusingaverage" "modelis%g"%(i,validate_acc)) xs,ys=mnist.train.next_batch(BATCH_SIZE) sess.run(train_op,feed_dict={x:xs,y_:ys}) #训练结束后在测试集上检测模型的最终正确率 test_acc=sess.run(accuracy,feed_dict=test_feed) print("After%dtrainingsteps,testaccuracyusingaveragemodel" "is%g"%(TRAINING_STEPS,test_acc)) #主程序入口 defmain(argv=None): mnist=input_data.read_data_sets("/tmp/data",one_hot=True) train(mnist) #Tensorflow主程序入口 if__name__=='__main__': tf.app.run()
输出结果如下:
Extracting/tmp/data/train-images-idx3-ubyte.gz Extracting/tmp/data/train-labels-idx1-ubyte.gz Extracting/tmp/data/t10k-images-idx3-ubyte.gz Extracting/tmp/data/t10k-labels-idx1-ubyte.gz After0trainingstep(s),validationaccuracyusingaveragemodelis0.0462 After1000trainingstep(s),validationaccuracyusingaveragemodelis0.9784 After2000trainingstep(s),validationaccuracyusingaveragemodelis0.9806 After3000trainingstep(s),validationaccuracyusingaveragemodelis0.9798 After4000trainingstep(s),validationaccuracyusingaveragemodelis0.9814 After5000trainingstep(s),validationaccuracyusingaveragemodelis0.9826 After6000trainingstep(s),validationaccuracyusingaveragemodelis0.9828 After7000trainingstep(s),validationaccuracyusingaveragemodelis0.9832 After8000trainingstep(s),validationaccuracyusingaveragemodelis0.9838 After9000trainingstep(s),validationaccuracyusingaveragemodelis0.983 After10000trainingstep(s),validationaccuracyusingaveragemodelis0.9836 After11000trainingstep(s),validationaccuracyusingaveragemodelis0.9822 After12000trainingstep(s),validationaccuracyusingaveragemodelis0.983 After13000trainingstep(s),validationaccuracyusingaveragemodelis0.983 After14000trainingstep(s),validationaccuracyusingaveragemodelis0.9844 After15000trainingstep(s),validationaccuracyusingaveragemodelis0.9832 After16000trainingstep(s),validationaccuracyusingaveragemodelis0.9844 After17000trainingstep(s),validationaccuracyusingaveragemodelis0.9842 After18000trainingstep(s),validationaccuracyusingaveragemodelis0.9842 After19000trainingstep(s),validationaccuracyusingaveragemodelis0.9838 After20000trainingstep(s),validationaccuracyusingaveragemodelis0.9834 After21000trainingstep(s),validationaccuracyusingaveragemodelis0.9828 After22000trainingstep(s),validationaccuracyusingaveragemodelis0.9834 After23000trainingstep(s),validationaccuracyusingaveragemodelis0.9844 After24000trainingstep(s),validationaccuracyusingaveragemodelis0.9838 After25000trainingstep(s),validationaccuracyusingaveragemodelis0.9834 After26000trainingstep(s),validationaccuracyusingaveragemodelis0.984 After27000trainingstep(s),validationaccuracyusingaveragemodelis0.984 After28000trainingstep(s),validationaccuracyusingaveragemodelis0.9836 After29000trainingstep(s),validationaccuracyusingaveragemodelis0.9842 After30000trainingsteps,testaccuracyusingaveragemodelis0.9839
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。