mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-04 11:10:53 +00:00
final touches
This commit is contained in:
parent
4e4d3745f0
commit
e1daa8cba4
3 changed files with 5 additions and 4 deletions
|
@ -95,6 +95,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
||||||
t1=1.0,
|
t1=1.0,
|
||||||
dt0=None,
|
dt0=None,
|
||||||
y0=y0,
|
y0=y0,
|
||||||
|
args=cosmo,
|
||||||
stepsize_controller=controller,
|
stepsize_controller=controller,
|
||||||
saveat=saveat)
|
saveat=saveat)
|
||||||
|
|
||||||
|
@ -140,6 +141,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
||||||
t1=1.0,
|
t1=1.0,
|
||||||
dt0=None,
|
dt0=None,
|
||||||
y0=y0,
|
y0=y0,
|
||||||
|
args=cosmo,
|
||||||
stepsize_controller=controller,
|
stepsize_controller=controller,
|
||||||
saveat=saveat)
|
saveat=saveat)
|
||||||
|
|
||||||
|
|
|
@ -258,7 +258,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
||||||
def test_fwd_rev_gradients(cosmo, pdims):
|
def test_fwd_rev_gradients(cosmo, pdims):
|
||||||
|
|
||||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||||
# SINGLE DEVICE RUN
|
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
|
||||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||||
|
@ -328,7 +327,6 @@ def test_fwd_rev_gradients(cosmo, pdims):
|
||||||
def test_vmap(cosmo, pdims):
|
def test_vmap(cosmo, pdims):
|
||||||
|
|
||||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||||
# SINGLE DEVICE RUN
|
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
|
||||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||||
|
|
|
@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
particles,
|
particles,
|
||||||
a=lpt_scale_factor,
|
a=lpt_scale_factor,
|
||||||
order=order)
|
order=order)
|
||||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||||
y0 = jnp.stack([particles + dx, p])
|
y0 = jnp.stack([particles + dx, p])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
a=lpt_scale_factor,
|
a=lpt_scale_factor,
|
||||||
order=order)
|
order=order)
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||||
y0 = jnp.stack([dx, p])
|
y0 = jnp.stack([dx, p])
|
||||||
|
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
|
@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
t1=1.0,
|
t1=1.0,
|
||||||
dt0=None,
|
dt0=None,
|
||||||
y0=y0,
|
y0=y0,
|
||||||
|
args=cosmo,
|
||||||
adjoint=adjoint,
|
adjoint=adjoint,
|
||||||
stepsize_controller=controller,
|
stepsize_controller=controller,
|
||||||
saveat=saveat)
|
saveat=saveat)
|
||||||
|
|
Loading…
Add table
Reference in a new issue