remove deprecated stuff

This commit is contained in:
Francois Lanusse 2024-10-24 16:36:41 -04:00 committed by Wassim KABALAN
parent 8e8e8964be
commit 0f833f0cb4
8 changed files with 96 additions and 831 deletions

View file

@ -22,16 +22,16 @@ from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
size = 64
size = 256
mesh_shape = [size] * 3
box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 4
halo_size = 32
pdims = (1, 1)
mesh = None
sharding = None
if jax.device_count() > 1:
pdims = (2, 4)
pdims = (2, 2)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

View file

@ -27,7 +27,7 @@ def initialize_distributed():
on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ[
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax