1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- from sympy import Symbol, Eq
- import numpy as np
- from simnet.solver import Solver
- from simnet.dataset import TrainDomain, ValidationDomain
- from simnet.data import Validation
- from simnet.sympy_utils.geometry_1d import Point1D
- from simnet.controller import SimNetController
- from simnet.plot_utils.vtk import var_to_vtk
- from spring_mass_ode import SpringMass
# Symbolic time variable; the network below takes `t` as its only input.
t_symbol = Symbol('t')
# NOTE(review): `x` is not referenced anywhere else in this file; it is kept
# only in case another module imports it.
x = Symbol('x')

# Simulation horizon: constraints are sampled for t in [0, t_max].
t_max = 10.0
time_range = {t_symbol: (0, t_max)}

# Single point at the origin. The problem is an ODE in time only, so the
# training constraints attach to this point and vary `t` through
# `param_ranges`.
geo = Point1D(0)
class SpringMassTrain(TrainDomain):
    """Training constraints for the three-mass spring system.

    Two constraint sets are registered on the single-point geometry ``geo``:

    * ``"IC"``: at ``t = 0`` the targets are ``x1 = 1`` and ``x2 = x3 = 0``,
      and ``x1__t``, ``x2__t``, ``x3__t`` are all 0. These are presumably
      the initial velocities, following SimNet's ``__t`` derivative naming.
    * ``"Interior"``: ``ode_x1``, ``ode_x2`` and ``ode_x3`` are driven to 0
      at 500 points sampled over ``time_range``. These are assumed to be the
      residual outputs of the ``SpringMass`` equation node.

    Any keyword configuration passed in is accepted but not used.
    """

    def __init__(self, **config):
        super().__init__()

        # State at t = 0: only mass 1 is displaced and nothing is moving.
        initial_state = {
            'x1': 1.,
            'x2': 0,
            'x3': 0,
            'x1__t': 0,
            'x2__t': 0,
            'x3__t': 0,
        }
        self.add(
            geo.boundary_bc(outvar_sympy=initial_state,
                            param_ranges={t_symbol: 0},
                            batch_size_per_area=1),
            name="IC",
        )

        # Zero residual for each mass's equation over the whole time range.
        ode_residuals = {key: 0.0 for key in ('ode_x1', 'ode_x2', 'ode_x3')}
        self.add(
            geo.boundary_bc(outvar_sympy=ode_residuals,
                            param_ranges=time_range,
                            batch_size_per_area=500),
            name="Interior",
        )
class SpringMassVal(ValidationDomain):
    """Validation data: the closed-form trajectories on a uniform time grid.

    Each reference displacement is a weighted sum of three cosine modes with
    angular frequencies 1, sqrt(3) and 2. The weights give
    ``(x1, x2, x3) = (1, 0, 0)`` at ``t = 0``, which matches the ``"IC"``
    training constraint.

    NOTE(review): these coefficients are presumably the analytic solution for
    ``k=(2, 1, 1, 2)`` and ``m=(1, 1, 1)`` as used in ``SpringMassSolver``.
    If those parameters change, this data must be recomputed.
    Any keyword configuration passed in is accepted but not used.
    """

    def __init__(self, **config):
        super().__init__()
        deltaT = 0.001
        # Column vector of shape (N, 1) covering [0, t_max). The endpoint is
        # excluded by np.arange.
        t = np.arange(0, t_max, deltaT)[:, np.newaxis]

        # Evaluate each normal mode once and reuse it in all three outputs.
        mode_1 = np.cos(t)
        mode_sqrt3 = np.cos(np.sqrt(3)*t)
        mode_2 = np.cos(2*t)

        outvar_numpy = {
            'x1': (1/6)*mode_1 + (1/2)*mode_sqrt3 + (1/3)*mode_2,
            'x2': (2/6)*mode_1 + (0/2)*mode_sqrt3 - (1/3)*mode_2,
            'x3': (1/6)*mode_1 - (1/2)*mode_sqrt3 + (1/3)*mode_2,
        }
        self.add(Validation.from_numpy({'t': t}, outvar_numpy), name="Val")
class SpringMassSolver(Solver):
    """Solver that couples the spring-mass equations to one network.

    The network maps ``t`` to ``(x1, x2, x3)``. It is trained on
    ``SpringMassTrain`` and checked against ``SpringMassVal``.
    """

    train_domain = SpringMassTrain
    val_domain = SpringMassVal

    def __init__(self, **config):
        super().__init__(**config)
        # k and m are presumably the spring stiffnesses (4 springs) and the
        # masses (3 bodies) of the chain. Confirm in spring_mass_ode.
        spring_mass = SpringMass(k=(2, 1, 1, 2), m=(1, 1, 1))
        self.equations = spring_mass.make_node()
        self.nets = [
            self.arch.make_node(name='spring_net',
                                inputs=['t'],
                                outputs=['x1', 'x2', 'x3']),
        ]

    @classmethod
    def update_defaults(cls, defaults):
        """Override entries of the default configuration mapping in place.

        This is a classmethod because it works on configuration rather than
        on a solver instance. It receives the class (``cls``, unused here)
        and the ``defaults`` mapping, writes this example's values into it,
        and returns nothing.

        NOTE(review): presumably ``SimNetController`` calls this on the class
        before constructing the solver, so that these values take effect at
        construction time. Confirm against the SimNet controller source.
        """
        overrides = {
            'network_dir': './network_checkpoint_spring_mass',
            'max_steps': 10000,
            'decay_steps': 100,
            'nr_layers': 6,
            'layer_size': 256,
            'xla': True,
        }
        defaults.update(overrides)
if __name__ == '__main__':
    # Wrap the solver class in a controller and start the run.
    SimNetController(SpringMassSolver).run()
|