mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 12:31:11 +00:00
remove deprecated stuff
This commit is contained in:
parent
8e8e8964be
commit
0f833f0cb4
8 changed files with 96 additions and 831 deletions
|
@ -22,16 +22,16 @@ from jaxpm.kernels import interpolate_power_spectrum
|
|||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
size = 64
|
||||
size = 256
|
||||
mesh_shape = [size] * 3
|
||||
box_size = [float(size)] * 3
|
||||
snapshots = jnp.linspace(0.1, 1., 4)
|
||||
halo_size = 4
|
||||
halo_size = 32
|
||||
pdims = (1, 1)
|
||||
mesh = None
|
||||
sharding = None
|
||||
if jax.device_count() > 1:
|
||||
pdims = (2, 4)
|
||||
pdims = (2, 2)
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
|
|
|
@ -27,7 +27,7 @@ def initialize_distributed():
|
|||
on_cluster = False
|
||||
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
os.environ[
|
||||
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
|
||||
import jax
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue