From 0b08c6f59ab57617a33071d8d8a5ccffcff8bb64 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Wed, 26 Feb 2025 14:05:25 +0100 Subject: [PATCH] Use cosmo as arg for the ODE function --- jaxpm/pm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index e34d584..9951e1c 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape, return nbody_ode -def make_diffrax_ode(cosmo, - mesh_shape, +def make_diffrax_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None): @@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo, state is a tuple (position, velocities) """ pos, vel = state + cosmo = args forces = pm_forces(pos, mesh_shape=mesh_shape,