mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 00:51:11 +00:00
Fix sharding error (#37)
* Use cosmo as arg for the ODE function * Update examples * format * notebook update * fix tests * add correct annotations for weights in painting and warning for cic_paint in distributed pm * update test_against_fpm * update distributed tests and add jacfwd jacrev and vmap tests * format * add Caveats to notebook readme * final touches * update Growth.py to allow using FastPM solver * fix 2D painting when input is (X , Y , 2) shape * update cic read halo size and notebooks examples * Allow env variable control of caching in growth * Format * update test jax version * update notebooks/03-MultiGPU_PM_Halo.ipynb * update numpy install in wf * update tolerance :) * reorganize install in test workflow * update tests * add mpi4py * update tests.yml * update tests * update wf * format * make normal_field signature consistent with jax.random.normal * update by default normal_field dtype to match JAX * format * debug test workflow * format * debug test workflow * updating tests * fix accuracy * fixed tolerance * adding caching * Update conftest.py * Update tolerance and precision settings in distributed PM tests * revererting back changes to growth.py --------- Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
parent
cb2a7ab17f
commit
6693e5c725
17 changed files with 675 additions and 298 deletions
|
@ -17,9 +17,8 @@ import jax_cosmo as jc
|
|||
import numpy as np
|
||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
PIDController, SaveAt, diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
|
@ -78,7 +77,7 @@ def parse_arguments():
|
|||
|
||||
def create_mesh_and_sharding(pdims):
|
||||
devices = create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
return mesh, sharding
|
||||
|
||||
|
@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
sharding=sharding,
|
||||
halo_size=halo_size))
|
||||
|
||||
# Choose solver
|
||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue