multivariate-random.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. #!/usr/bin/env python
  2. import matplotlib.pyplot as plt
  3. import numpy
  4. import csv
  5. def main(n):
  6. cov = [[25, 20], [20, 25]]
  7. meanI = [70, 40]
  8. datapointsI = n
  9. meanII = [60, 20]
  10. datapointsII = n
  11. dataI = numpy.random.multivariate_normal(meanI, cov, datapointsI).T
  12. x, y = dataI
  13. plt.plot(x, y, 'x')
  14. dataII = numpy.random.multivariate_normal(meanII, cov, datapointsII).T
  15. x, y = dataII
  16. plt.plot(x, y, 'x')
  17. plt.axis('equal')
  18. plt.show()
  19. data = []
  20. xs, ys = dataI
  21. for x, y in zip(xs, ys):
  22. data.append([x, y, 'a'])
  23. xs, ys = dataII
  24. for x, y in zip(xs, ys):
  25. data.append([x, y, 'b'])
  26. # Write data to csv files
  27. with open("data.csv", 'wb') as csvfile:
  28. csvfile.write("x,y,label\n")
  29. spamwriter = csv.writer(csvfile, delimiter=',',
  30. quotechar='"', quoting=csv.QUOTE_MINIMAL)
  31. for datapoint in data:
  32. spamwriter.writerow(datapoint)
  33. if __name__ == "__main__":
  34. from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
  35. parser = ArgumentParser(description=__doc__,
  36. formatter_class=ArgumentDefaultsHelpFormatter)
  37. parser.add_argument("-n",
  38. dest="n", default=2000, type=int,
  39. help="how many points should get generated")
  40. args = parser.parse_args()
  41. main(args.n)