linear_regression_eager_api.py

'''
A linear regression learning algorithm example using TensorFlow's Eager API.

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''
from __future__ import absolute_import, division, print_function

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe

# Set Eager API (must be called once, at program startup, before any TF ops)
tfe.enable_eager_execution()
# Training Data
train_X = [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
           7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1]
train_Y = [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
           2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3]
n_samples = len(train_X)
# Parameters
learning_rate = 0.01
display_step = 100
num_steps = 1000
# Weight and Bias
W = tfe.Variable(np.random.randn())
b = tfe.Variable(np.random.randn())
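# Note: W and b are scalar trainable variables initialized from a standard
# normal distribution. tfe.Variable is the eager-compatible variable class of
# contrib-era TensorFlow; in later releases plain tf.Variable works under
# eager execution as well.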
# Linear regression (Wx + b)
def linear_regression(inputs):
    return inputs * W + b
# Mean square error
def mean_square_fn(model_fn, inputs, labels):
    return tf.reduce_sum(tf.pow(model_fn(inputs) - labels, 2)) / (2 * n_samples)
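# The cost is J(W, b) = sum_i (W * x_i + b - y_i)^2 / (2 * n_samples); the
# extra factor of 2 in the denominator cancels the 2 produced when the square
# is differentiated, which keeps the gradient expressions tidy.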
# SGD Optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)

# Compute gradients
grad = tfe.implicit_gradients(mean_square_fn)
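# tfe.implicit_gradients wraps mean_square_fn in a function that, when called
# with the same arguments, returns a list of (gradient, variable) pairs for
# every trainable variable the wrapped function touches (here W and b), in
# exactly the format optimizer.apply_gradients expects.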
# Initial cost, before optimizing
print("Initial cost= {:.9f}".format(
    mean_square_fn(linear_regression, train_X, train_Y)),
    "W=", W.numpy(), "b=", b.numpy())
# Training: each step consumes the full training set (batch gradient
# descent), so a step here corresponds to one epoch.
for step in range(num_steps):
    optimizer.apply_gradients(grad(linear_regression, train_X, train_Y))

    if (step + 1) % display_step == 0 or step == 0:
        print("Epoch:", '%04d' % (step + 1), "cost=",
              "{:.9f}".format(mean_square_fn(linear_regression, train_X, train_Y)),
              "W=", W.numpy(), "b=", b.numpy())
# Graphic display
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(train_X, np.array(W * train_X + b), label='Fitted line')
plt.legend()
plt.show()
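
# For reference only: a minimal sketch of the same batch-gradient-descent loop
# under TensorFlow 2.x, where tf.contrib/tfe were removed, eager execution is
# the default, and gradients are taken with tf.GradientTape. This sketch is an
# assumption about the migration path, not part of the original example.
#
#   X = tf.constant(train_X, dtype=tf.float64)
#   Y = tf.constant(train_Y, dtype=tf.float64)
#   W = tf.Variable(np.random.randn())
#   b = tf.Variable(np.random.randn())
#   optimizer = tf.optimizers.SGD(learning_rate)
#   for step in range(num_steps):
#       with tf.GradientTape() as tape:
#           cost = tf.reduce_sum((W * X + b - Y) ** 2) / (2 * n_samples)
#       grads = tape.gradient(cost, [W, b])
#       optimizer.apply_gradients(zip(grads, [W, b]))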