mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
format
This commit is contained in:
parent
1f5c619531
commit
a8f11a75ea
1 changed files with 0 additions and 2 deletions
|
@ -3,7 +3,6 @@ import os
|
||||||
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||||
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||||
|
|
||||||
|
|
||||||
os.environ["EQX_ON_ERROR"] = "nan"
|
os.environ["EQX_ON_ERROR"] = "nan"
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
@ -25,7 +24,6 @@ from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||||
|
|
||||||
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
|
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
|
||||||
|
|
||||||
|
|
||||||
all_gather = partial(process_allgather, tiled=False)
|
all_gather = partial(process_allgather, tiled=False)
|
||||||
|
|
||||||
pdims = (2, 4)
|
pdims = (2, 4)
|
||||||
|
|
Loading…
Add table
Reference in a new issue