在Tensorflow中实现梯度下降法更新参数值
我就废话不多说了,直接上代码吧!
tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
TensorFlow经过使用梯度下降法对损失函数中的变量进行修改值,默认修改tf.Variable(tf.zeros([784,10]))
为Variable的参数。
train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[w,b])
也可以使用var_list参数来定义更新那些参数的值
#导入Minst数据集 importinput_data mnist=input_data.read_data_sets("data",one_hot=True) #导入tensorflow库 importtensorflowastf #输入变量,把28*28的图片变成一维数组(丢失结构信息) x=tf.placeholder("float",[None,784]) #权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出 w=tf.Variable(tf.zeros([784,10])) #偏置 b=tf.Variable(tf.zeros([10])) #核心运算,其实就是softmax(x*w+b) y=tf.nn.softmax(tf.matmul(x,w)+b) #这个是训练集的正确结果 y_=tf.placeholder("float",[None,10]) #交叉熵,作为损失函数 cross_entropy=-tf.reduce_sum(y_*tf.log(y)) #梯度下降算法,最小化交叉熵 train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #初始化,在run之前必须进行的 init=tf.initialize_all_variables() #创建session以便运算 sess=tf.Session() sess.run(init) #迭代1000次 foriinrange(1000): #获取训练数据集的图片输入和正确表示数字 batch_xs,batch_ys=mnist.train.next_batch(100) #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字 sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys}) #tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。 #这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。 #1代表正确,0代表错误 correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) #tf.cast先将数据转换成float,防止求平均不准确。 #tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。 accuracy=tf.reduce_mean(tf.cast(correct_prediction,"float")) #输出 print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels}))
计算结果如下
"C:\ProgramFiles\Anaconda3\python.exe"D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py Extractingdata\train-images-idx3-ubyte.gz Extractingdata\train-labels-idx1-ubyte.gz Extractingdata\t10k-images-idx3-ubyte.gz Extractingdata\t10k-labels-idx1-ubyte.gz WARNING:tensorflow:FromC:\ProgramFiles\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175:initialize_all_variables(fromtensorflow.python.ops.variables)isdeprecatedandwillberemovedafter2017-03-02. Instructionsforupdating: Use`tf.global_variables_initializer`instead. 2018-05-1415:49:45.866600:WC:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45]TheTensorFlowlibrarywasn'tcompiledtouseAVXinstructions,buttheseareavailableonyourmachineandcouldspeedupCPUcomputations. 2018-05-1415:49:45.866600:WC:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45]TheTensorFlowlibrarywasn'tcompiledtouseAVX2instructions,buttheseareavailableonyourmachineandcouldspeedupCPUcomputations. 0.9163 Processfinishedwithexitcode0
如果限制,只更新参数W查看效果
"C:\ProgramFiles\Anaconda3\python.exe"D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py Extractingdata\train-images-idx3-ubyte.gz Extractingdata\train-labels-idx1-ubyte.gz Extractingdata\t10k-images-idx3-ubyte.gz Extractingdata\t10k-labels-idx1-ubyte.gz WARNING:tensorflow:FromC:\ProgramFiles\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175:initialize_all_variables(fromtensorflow.python.ops.variables)isdeprecatedandwillberemovedafter2017-03-02. Instructionsforupdating: Use`tf.global_variables_initializer`instead. 2018-05-1415:51:08.543600:WC:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45]TheTensorFlowlibrarywasn'tcompiledtouseAVXinstructions,buttheseareavailableonyourmachineandcouldspeedupCPUcomputations. 2018-05-1415:51:08.544600:WC:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45]TheTensorFlowlibrarywasn'tcompiledtouseAVX2instructions,buttheseareavailableonyourmachineandcouldspeedupCPUcomputations. 0.9187 Processfinishedwithexitcode0
可以看出只修改W对结果影响不大,如果设置只修改b
#导入Minst数据集 importinput_data mnist=input_data.read_data_sets("data",one_hot=True) #导入tensorflow库 importtensorflowastf #输入变量,把28*28的图片变成一维数组(丢失结构信息) x=tf.placeholder("float",[None,784]) #权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出 w=tf.Variable(tf.zeros([784,10])) #偏置 b=tf.Variable(tf.zeros([10])) #核心运算,其实就是softmax(x*w+b) y=tf.nn.softmax(tf.matmul(x,w)+b) #这个是训练集的正确结果 y_=tf.placeholder("float",[None,10]) #交叉熵,作为损失函数 cross_entropy=-tf.reduce_sum(y_*tf.log(y)) #梯度下降算法,最小化交叉熵 train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[b]) #初始化,在run之前必须进行的 init=tf.initialize_all_variables() #创建session以便运算 sess=tf.Session() sess.run(init) #迭代1000次 foriinrange(1000): #获取训练数据集的图片输入和正确表示数字 batch_xs,batch_ys=mnist.train.next_batch(100) #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字 sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys}) #tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。 #这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。 #1代表正确,0代表错误 correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) #tf.cast先将数据转换成float,防止求平均不准确。 #tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。 accuracy=tf.reduce_mean(tf.cast(correct_prediction,"float")) #输出 print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels}))
计算结果:
"C:\ProgramFiles\Anaconda3\python.exe"D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py Extractingdata\train-images-idx3-ubyte.gz Extractingdata\train-labels-idx1-ubyte.gz Extractingdata\t10k-images-idx3-ubyte.gz Extractingdata\t10k-labels-idx1-ubyte.gz WARNING:tensorflow:FromC:\ProgramFiles\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175:initialize_all_variables(fromtensorflow.python.ops.variables)isdeprecatedandwillberemovedafter2017-03-02. Instructionsforupdating: Use`tf.global_variables_initializer`instead. 2018-05-1415:52:04.483600:WC:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45]TheTensorFlowlibrarywasn'tcompiledtouseAVXinstructions,buttheseareavailableonyourmachineandcouldspeedupCPUcomputations. 2018-05-1415:52:04.483600:WC:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45]TheTensorFlowlibrarywasn'tcompiledtouseAVX2instructions,buttheseareavailableonyourmachineandcouldspeedupCPUcomputations. 0.1135 Processfinishedwithexitcode0
如果只更新b那么对效果影响很大。
以上这篇在Tensorflow中实现梯度下降法更新参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。