plot_results_spring.py 928 B

1234567891011121314151617181920212223242526272829
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import simnet as sn
  4. base_dir = './network_checkpoint_spring_mass_pointMass_2/val_domain/results/'
  5. # plot in 1d
  6. predicted_data = np.load(base_dir + 'Val_pred.npz', allow_pickle=True)
  7. true_data = np.load(base_dir + 'Val_true.npz', allow_pickle=True)
  8. true_data = np.atleast_1d(true_data.f.arr_0)[0]
  9. predicted_data = np.atleast_1d(predicted_data.f.arr_0)[0]
  10. print(predicted_data)
  11. print(true_data)
  12. plt.plot(true_data['t'], true_data['x1'], label='True x1')
  13. plt.plot(true_data['t'], true_data['x2'], label='True x2')
  14. plt.plot(true_data['t'], true_data['x3'], label='True x3')
  15. plt.plot(predicted_data['t'], predicted_data['x1'], label='Pred x1')
  16. plt.plot(predicted_data['t'], predicted_data['x2'], label='Pred x2')
  17. plt.plot(predicted_data['t'], predicted_data['x3'], label='Pred x3')
  18. plt.xlabel("Time")
  19. plt.ylabel("Displacement")
  20. plt.legend()
  21. plt.savefig("comparison_new.png")