mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-04 11:10:53 +00:00
Update examples
This commit is contained in:
parent
0b08c6f59a
commit
8e0f300572
5 changed files with 132 additions and 95 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -71,20 +71,17 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax.experimental.mesh_utils import create_device_mesh\n",
|
||||
"from jax.experimental.multihost_utils import process_allgather\n",
|
||||
"from jax.sharding import Mesh, NamedSharding\n",
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"from jax.sharding import PartitionSpec as P, NamedSharding\n",
|
||||
"\n",
|
||||
"all_gather = partial(process_allgather, tiled=False)\n",
|
||||
"\n",
|
||||
"pdims = (2, 4)\n",
|
||||
"devices = create_device_mesh(pdims)\n",
|
||||
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
|
||||
"mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n",
|
||||
"sharding = NamedSharding(mesh, P('x', 'y'))"
|
||||
]
|
||||
},
|
||||
|
@ -124,7 +121,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
|
||||
" solver = LeapfrogMidpoint()\n",
|
||||
"\n",
|
||||
" stepsize_controller = ConstantStepSize()\n",
|
||||
|
@ -288,7 +285,7 @@
|
|||
"\n",
|
||||
" # Evolve the simulation forward\n",
|
||||
" ode_fn = ODETerm(\n",
|
||||
" make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
|
||||
" make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))\n",
|
||||
" solver = Dopri5()\n",
|
||||
"\n",
|
||||
" stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n",
|
||||
|
|
|
@ -17,10 +17,8 @@ import jax_cosmo as jc
|
|||
import numpy as np
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.sharding import PartitionSpec as P, NamedSharding
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
|
@ -77,8 +75,8 @@ def parse_arguments():
|
|||
|
||||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
|
||||
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
||||
|
@ -106,7 +104,7 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False , halo_size=halo_size , sharding=sharding))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
|
|
Loading…
Add table
Reference in a new issue