Kaynağa Gözat

added testing as requested by issue #2

aymericdamien 9 yıl önce
ebeveyn
işleme
104c4de9f3

+ 14 - 1
examples/2 - Basic Classifiers/linear_regression.py

@@ -55,10 +55,23 @@ with tf.Session() as sess:
                 "W=", sess.run(W), "b=", sess.run(b)
 
     print "Optimization Finished!"
-    print "cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b)
+    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
+    print "Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n'
+
+
+    # Testing example, as requested (Issue #2)
+    test_X = numpy.asarray([6.83,4.668,8.9,7.91,5.7,8.7,3.1,2.1])
+    test_Y = numpy.asarray([1.84,2.273,3.2,2.831,2.92,3.24,1.35,1.03])
+
+    print "Testing... (L2 loss Comparison)"
+    testing_cost = sess.run(tf.reduce_sum(tf.pow(activation-Y, 2))/(2*test_X.shape[0]),
+                            feed_dict={X: test_X, Y: test_Y}) #same function as cost above
+    print "Testing cost=", testing_cost
+    print "Absolute l2 loss difference:", abs(training_cost - testing_cost)
 
     #Graphic display
     plt.plot(train_X, train_Y, 'ro', label='Original data')
+    plt.plot(test_X, test_Y, 'bo', label='Testing data')
     plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
     plt.legend()
     plt.show()