spring_mass_solver.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from sympy import Symbol, Eq
  2. import numpy as np
  3. from simnet.solver import Solver
  4. from simnet.dataset import TrainDomain, ValidationDomain
  5. from simnet.data import Validation
  6. from simnet.sympy_utils.geometry_1d import Point1D
  7. from simnet.controller import SimNetController
  8. from simnet.plot_utils.vtk import var_to_vtk
  9. from spring_mass_ode import SpringMass
  10. # define time variable and range
  11. t_max = 10.0
  12. t_symbol = Symbol('t')
  13. x = Symbol('x')
  14. time_range = {t_symbol: (0, t_max)}
  15. geo = Point1D(0)
  16. class SpringMassTrain(TrainDomain):
  17. def __init__(self, **config):
  18. super(SpringMassTrain, self).__init__()
  19. # initial conditions
  20. IC = geo.boundary_bc(outvar_sympy={'x1': 1.,
  21. 'x2': 0,
  22. 'x3': 0,
  23. 'x1__t': 0,
  24. 'x2__t': 0,
  25. 'x3__t': 0},
  26. param_ranges={t_symbol: 0},
  27. batch_size_per_area=1)
  28. self.add(IC, name="IC")
  29. # solve over given time period
  30. interior = geo.boundary_bc(outvar_sympy={'ode_x1': 0.0,
  31. 'ode_x2': 0.0,
  32. 'ode_x3': 0.0},
  33. param_ranges=time_range,
  34. batch_size_per_area=500)
  35. self.add(interior, name="Interior")
  36. class SpringMassVal(ValidationDomain):
  37. def __init__(self, **config):
  38. super(SpringMassVal, self).__init__()
  39. deltaT = 0.001
  40. t = np.arange(0, t_max, deltaT)
  41. t = np.expand_dims(t, axis=-1)
  42. invar_numpy = {'t': t}
  43. outvar_numpy = {'x1': (1/6)*np.cos(t) + (1/2)*np.cos(np.sqrt(3)*t) + (1/3)*np.cos(2*t),
  44. 'x2': (2/6)*np.cos(t) + (0/2)*np.cos(np.sqrt(3)*t) - (1/3)*np.cos(2*t),
  45. 'x3': (1/6)*np.cos(t) - (1/2)*np.cos(np.sqrt(3)*t) + (1/3)*np.cos(2*t)}
  46. val = Validation.from_numpy(invar_numpy, outvar_numpy)
  47. self.add(val, name="Val")
  48. class SpringMassSolver(Solver):
  49. train_domain = SpringMassTrain
  50. val_domain = SpringMassVal
  51. def __init__(self, **config):
  52. super(SpringMassSolver, self).__init__(**config)
  53. self.equations = SpringMass(k=(2, 1, 1, 2), m=(1, 1, 1)).make_node()
  54. spring_net = self.arch.make_node(name='spring_net',
  55. inputs=['t'],
  56. outputs=['x1','x2','x3'])
  57. self.nets = [spring_net]
  58. @classmethod # Explain This
  59. def update_defaults(cls, defaults):
  60. defaults.update({
  61. 'network_dir': './network_checkpoint_spring_mass',
  62. 'max_steps': 10000,
  63. 'decay_steps': 100,
  64. 'nr_layers': 6,
  65. 'layer_size': 256,
  66. 'xla': True,
  67. })
  68. if __name__ == '__main__':
  69. ctr = SimNetController(SpringMassSolver)
  70. ctr.run()