diff --git a/tests/test_sharded_array.py b/tests/test_sharded_array.py index 93d23b6..47bc64a 100644 --- a/tests/test_sharded_array.py +++ b/tests/test_sharded_array.py @@ -3,7 +3,6 @@ import os #os.environ["JAX_PLATFORM_NAME"] = "cpu" #os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" - os.environ["EQX_ON_ERROR"] = "nan" from functools import partial @@ -25,7 +24,6 @@ from jaxpm.pm import linear_field, lpt, make_diffrax_ode #assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices" - all_gather = partial(process_allgather, tiled=False) pdims = (2, 4)