mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-11 21:50:55 +00:00
Use cosmo as arg for the ODE function
This commit is contained in:
parent
cb2a7ab17f
commit
0b08c6f59a
1 changed files with 2 additions and 2 deletions
|
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def make_diffrax_ode(cosmo,
|
def make_diffrax_ode(mesh_shape,
|
||||||
mesh_shape,
|
|
||||||
paint_absolute_pos=True,
|
paint_absolute_pos=True,
|
||||||
halo_size=0,
|
halo_size=0,
|
||||||
sharding=None):
|
sharding=None):
|
||||||
|
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
|
cosmo = args
|
||||||
|
|
||||||
forces = pm_forces(pos,
|
forces = pm_forces(pos,
|
||||||
mesh_shape=mesh_shape,
|
mesh_shape=mesh_shape,
|
||||||
|
|
Loading…
Add table
Reference in a new issue