plot_clustering.py 1.0 KB

123456789101112131415161718192021222324252627282930313233
  1. import matplotlib.pyplot as plt
  2. def plot_kmeans_clustering_results(c1, c2, c3, vq1, vq2, vq3):
  3. # Setting plot limits
  4. x1, x2 = -10, 10
  5. y1, y2 = -10, 10
  6. fig = plt.figure()
  7. fig.subplots_adjust(hspace=0.1, wspace=0.1)
  8. ax1 = fig.add_subplot(121, aspect='equal')
  9. ax1.scatter(c1[:, 0], c1[:, 1], lw=0.5, color='#00CC00')
  10. ax1.scatter(c2[:, 0], c2[:, 1], lw=0.5, color='#028E9B')
  11. ax1.scatter(c3[:, 0], c3[:, 1], lw=0.5, color='#FF7800')
  12. ax1.xaxis.set_visible(False)
  13. ax1.yaxis.set_visible(False)
  14. ax1.set_xlim(x1, x2)
  15. ax1.set_ylim(y1, y2)
  16. ax1.text(-9, 8, 'Original')
  17. ax2 = fig.add_subplot(122, aspect='equal')
  18. ax2.scatter(vqc1[:, 0], vqc1[:, 1], lw=0.5, color='#00CC00')
  19. ax2.scatter(vqc2[:, 0], vqc2[:, 1], lw=0.5, color='#028E9B')
  20. ax2.scatter(vqc3[:, 0], vqc3[:, 1], lw=0.5, color='#FF7800')
  21. ax2.xaxis.set_visible(False)
  22. ax2.yaxis.set_visible(False)
  23. ax2.set_xlim(x1, x2)
  24. ax2.set_ylim(y1, y2)
  25. ax2.text(-9, 8, 'VQ identified')
  26. return fig