mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 12:01:12 +00:00
final touches
This commit is contained in:
parent
4e4d3745f0
commit
e1daa8cba4
3 changed files with 5 additions and 4 deletions
|
@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
|
||||
else:
|
||||
|
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
|
@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
adjoint=adjoint,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue