From 6aa2ccf7b5fc998aa37fd1351c9836dc49080c83 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Tue, 21 Jan 2025 11:18:00 +0100 Subject: [PATCH] don't send static cosmo into ODETerm --- notebooks/02-Advanced_usage.ipynb | 4 ++-- notebooks/03-MultiGPU_PM_Halo.ipynb | 4 ++-- notebooks/04-MultiGPU_PM_Solvers.ipynb | 4 ++-- notebooks/05-MultiHost_PM.py | 2 +- tests/test_gradients.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/notebooks/02-Advanced_usage.ipynb b/notebooks/02-Advanced_usage.ipynb index cf7f611..8125ec0 100644 --- a/notebooks/02-Advanced_usage.ipynb +++ b/notebooks/02-Advanced_usage.ipynb @@ -84,7 +84,7 @@ " \n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n", " solver = LeapfrogMidpoint()\n", "\n", " stepsize_controller = ConstantStepSize()\n", @@ -257,7 +257,7 @@ " \n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape))\n", + " make_diffrax_ode(mesh_shape))\n", " solver = LeapfrogMidpoint()\n", "\n", " stepsize_controller = ConstantStepSize()\n", diff --git a/notebooks/03-MultiGPU_PM_Halo.ipynb b/notebooks/03-MultiGPU_PM_Halo.ipynb index 0a652d2..28ffe1b 100644 --- a/notebooks/03-MultiGPU_PM_Halo.ipynb +++ b/notebooks/03-MultiGPU_PM_Halo.ipynb @@ -180,7 +180,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n", " solver = LeapfrogMidpoint()\n", "\n", " stepsize_controller = ConstantStepSize()\n", @@ -410,7 +410,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n", " solver = LeapfrogMidpoint()\n", "\n", " stepsize_controller = ConstantStepSize()\n", diff --git a/notebooks/04-MultiGPU_PM_Solvers.ipynb b/notebooks/04-MultiGPU_PM_Solvers.ipynb index 7671bc7..4ca9a31 100644 --- a/notebooks/04-MultiGPU_PM_Solvers.ipynb +++ b/notebooks/04-MultiGPU_PM_Solvers.ipynb @@ -124,7 +124,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n", " solver = LeapfrogMidpoint()\n", "\n", " stepsize_controller = ConstantStepSize()\n", @@ -288,7 +288,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n", " solver = Dopri5()\n", "\n", " stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n", diff --git a/notebooks/05-MultiHost_PM.py b/notebooks/05-MultiHost_PM.py index da3964e..5be7e60 100644 --- a/notebooks/05-MultiHost_PM.py +++ b/notebooks/05-MultiHost_PM.py @@ -106,7 +106,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) # Choose solver solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5() diff --git a/tests/test_gradients.py b/tests/test_gradients.py index bb48920..ad23224 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, particles, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: @@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, a=lpt_scale_factor, order=order) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) y0 = jnp.stack([dx, p]) solver = Dopri5()