TensorFlow教程04:MNIST实验——源码和运行结果

来源:互联网 发布:java心得体会 编辑:程序博客网 时间:2024/09/21 11:23

假定您已经安装好了TensorFlow,这里放了第一个MNIST实验的代码和参考结果,你可以直接运行验证。

源码

[python] view plain copy
print?
  1. #!/usr/bin/python  
  2. import tensorflow as tf  
  3. import sys  
  4. from tensorflow.examples.tutorials.mnist import input_data  
  5.   
  6. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  
  7.   
  8. x = tf.placeholder("float", [None784])  
  9. W = tf.Variable(tf.zeros([784,10]))  
  10. b = tf.Variable(tf.zeros([10]))  
  11.   
  12. y = tf.nn.softmax(tf.matmul(x,W) + b)  
  13.   
  14. y_ = tf.placeholder("float", [None,10])  
  15. cross_entropy = -tf.reduce_sum(y_*tf.log(y))  
  16.   
  17. train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  
  18.   
  19. init = tf.global_variables_initializer()  
  20. sess = tf.Session()  
  21. sess.run(init)  
  22.   
  23. for i in range(1000):  
  24.   if i % 20 == 0:  
  25.     sys.stdout.write('.')  
  26.   batch_xs, batch_ys = mnist.train.next_batch(100)  
  27.   sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})  
  28. print ""  
  29.   
  30. correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))  
  31. accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  
  32.   
  33. print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})  

运行结果

[plain] view plain copy
print?
  1. Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.  
  2. Extracting MNIST_data/train-images-idx3-ubyte.gz  
  3. Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.  
  4. Extracting MNIST_data/train-labels-idx1-ubyte.gz  
  5. Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.  
  6. Extracting MNIST_data/t10k-images-idx3-ubyte.gz  
  7. Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.  
  8. Extracting MNIST_data/t10k-labels-idx1-ubyte.gz  
  9. ..................................................  
  10. 0.9177  


阅读全文
0 0
原创粉丝点击