This commit is contained in:
Wassim Kabalan 2025-01-20 22:47:07 +01:00
parent 1f5c619531
commit a8f11a75ea

View file

@ -3,7 +3,6 @@ import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu" #os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" #os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["EQX_ON_ERROR"] = "nan" os.environ["EQX_ON_ERROR"] = "nan"
from functools import partial from functools import partial
@ -25,7 +24,6 @@ from jaxpm.pm import linear_field, lpt, make_diffrax_ode
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices" #assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
all_gather = partial(process_allgather, tiled=False) all_gather = partial(process_allgather, tiled=False)
pdims = (2, 4) pdims = (2, 4)