mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-15 10:21:11 +00:00
Merge 8e0f300572
into cb2a7ab17f
This commit is contained in:
commit
9b0e988910
6 changed files with 134 additions and 97 deletions
|
@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape,
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def make_diffrax_ode(cosmo,
|
def make_diffrax_ode(mesh_shape,
|
||||||
mesh_shape,
|
|
||||||
paint_absolute_pos=True,
|
paint_absolute_pos=True,
|
||||||
halo_size=0,
|
halo_size=0,
|
||||||
sharding=None):
|
sharding=None):
|
||||||
|
@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo,
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
|
cosmo = args
|
||||||
|
|
||||||
forces = pm_forces(pos,
|
forces = pm_forces(pos,
|
||||||
mesh_shape=mesh_shape,
|
mesh_shape=mesh_shape,
|
||||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -71,20 +71,17 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from jax.experimental.mesh_utils import create_device_mesh\n",
|
|
||||||
"from jax.experimental.multihost_utils import process_allgather\n",
|
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||||
"from jax.sharding import Mesh, NamedSharding\n",
|
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
|
||||||
"from jax.sharding import PartitionSpec as P\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pdims = (2, 4)\n",
|
"pdims = (2, 4)\n",
|
||||||
"devices = create_device_mesh(pdims)\n",
|
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
|
||||||
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
|
|
||||||
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -124,7 +121,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
|
||||||
" solver = LeapfrogMidpoint()\n",
|
" solver = LeapfrogMidpoint()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = ConstantStepSize()\n",
|
" stepsize_controller = ConstantStepSize()\n",
|
||||||
|
@ -288,7 +285,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
" # Evolve the simulation forward\n",
|
" # Evolve the simulation forward\n",
|
||||||
" ode_fn = ODETerm(\n",
|
" ode_fn = ODETerm(\n",
|
||||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
|
||||||
" solver = Dopri5()\n",
|
" solver = Dopri5()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||||
|
|
|
@ -17,10 +17,8 @@ import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
PIDController, SaveAt, diffeqsolve)
|
PIDController, SaveAt, diffeqsolve)
|
||||||
from jax.experimental.mesh_utils import create_device_mesh
|
|
||||||
from jax.experimental.multihost_utils import process_allgather
|
from jax.experimental.multihost_utils import process_allgather
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import PartitionSpec as P, NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P
|
|
||||||
|
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
from jaxpm.painting import cic_paint_dx
|
from jaxpm.painting import cic_paint_dx
|
||||||
|
@ -77,8 +75,8 @@ def parse_arguments():
|
||||||
|
|
||||||
|
|
||||||
def create_mesh_and_sharding(pdims):
|
def create_mesh_and_sharding(pdims):
|
||||||
devices = create_device_mesh(pdims)
|
|
||||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
|
||||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
return mesh, sharding
|
return mesh, sharding
|
||||||
|
|
||||||
|
@ -106,7 +104,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))
|
||||||
|
|
||||||
# Choose solver
|
# Choose solver
|
||||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue