multivariate-random.py 909 B

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #!/usr/bin/env python
  2. import matplotlib.pyplot as plt
  3. import numpy
  4. import csv
  5. cov = [[25, 20], [20, 25]] # diagonal covariance, points lie on x or y-axis
  6. meanI = [70, 40]
  7. datapointsI = 2000
  8. meanII = [60, 20]
  9. datapointsII = 2000
  10. dataI = numpy.random.multivariate_normal(meanI, cov, datapointsI).T
  11. x, y = dataI
  12. plt.plot(x, y, 'x')
  13. dataII = numpy.random.multivariate_normal(meanII, cov, datapointsII).T
  14. x, y = dataII
  15. plt.plot(x, y, 'x')
  16. plt.axis('equal')
  17. plt.show()
  18. data = []
  19. xs, ys = dataI
  20. for x, y in zip(xs, ys):
  21. data.append([x, y, 'a'])
  22. xs, ys = dataII
  23. for x, y in zip(xs, ys):
  24. data.append([x, y, 'b'])
  25. # Write data to csv files
  26. with open("data.csv", 'wb') as csvfile:
  27. csvfile.write("x,y,label\n")
  28. spamwriter = csv.writer(csvfile, delimiter=',',
  29. quotechar='"', quoting=csv.QUOTE_MINIMAL)
  30. for datapoint in data:
  31. spamwriter.writerow(datapoint)