mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
don't send static cosmo into ODETerm
This commit is contained in:
parent
6ab26ea1ec
commit
6aa2ccf7b5
5 changed files with 9 additions and 9 deletions
|
@ -84,7 +84,7 @@
|
||||||
" \n",
|
" \n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
|
@ -257,7 +257,7 @@
|
||||||
" \n",
|
" \n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape))\n",
|
" make_diffrax_ode(mesh_shape))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
|
|
|
@ -180,7 +180,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
|
@ -410,7 +410,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
|
|
|
@ -124,7 +124,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
|
@ -288,7 +288,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n",
|
||||||
" solver = Dopri5()\n",
|
" solver = Dopri5()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||||
|
|
|
@ -106,7 +106,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||||
|
|
||||||
# Choose solver
|
# Choose solver
|
||||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||||
|
|
|
@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
particles,
|
particles,
|
||||||
a=lpt_scale_factor,
|
a=lpt_scale_factor,
|
||||||
order=order)
|
order=order)
|
||||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||||
y0 = jnp.stack([particles + dx, p])
|
y0 = jnp.stack([particles + dx, p])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
a=lpt_scale_factor,
|
a=lpt_scale_factor,
|
||||||
order=order)
|
order=order)
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||||
y0 = jnp.stack([dx, p])
|
y0 = jnp.stack([dx, p])
|
||||||
|
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
|
|
Loading…
Add table
Reference in a new issue