1234567891011121314151617181920212223242526272829 |
- import numpy as np
- import matplotlib.pyplot as plt
- import simnet as sn
# Directory holding the validation-domain result archives read below.
base_dir = './network_checkpoint_spring_mass_pointMass_2/val_domain/results/'


def _load_first_record(npz_path):
    """Return the first element of the ``arr_0`` entry of an ``.npz`` archive.

    The archive is opened in a ``with`` block so its file handle is closed
    as soon as the payload has been read, instead of staying open for the
    rest of the script.

    Args:
        npz_path: Path of the ``.npz`` file to read.

    Returns:
        ``np.atleast_1d(arr_0)[0]``. The plotting code below indexes the
        result with the keys ``'t'``, ``'x1'``, ``'x2'`` and ``'x3'``, so it
        is presumably a dict that was saved as a 0-d object array -- confirm
        against the code that writes these files.
    """
    # NOTE(security): allow_pickle=True lets numpy unpickle arbitrary Python
    # objects, which can execute code. Only point this at result files that
    # were produced locally by a trusted run.
    with np.load(npz_path, allow_pickle=True) as archive:
        # atleast_1d promotes a 0-d array to shape (1,) so that [0] works
        # whether the payload was stored as a scalar or as a 1-element array.
        return np.atleast_1d(archive['arr_0'])[0]


# plot in 1d
predicted_data = _load_first_record(base_dir + 'Val_pred.npz')
true_data = _load_first_record(base_dir + 'Val_true.npz')
print(predicted_data)
print(true_data)

# Draw the three true curves first, then the three predicted ones, so the
# line order (and therefore the default colour cycle and legend order) is
# the same as plotting each series by hand.
for tag, record in (('True', true_data), ('Pred', predicted_data)):
    for key in ('x1', 'x2', 'x3'):
        plt.plot(record['t'], record[key], label='{} {}'.format(tag, key))

plt.xlabel("Time")
plt.ylabel("Displacement")
plt.legend()
plt.savefig("comparison_new.png")
|