mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 01:57:10 +00:00
format
This commit is contained in:
parent
1f5c619531
commit
a8f11a75ea
1 changed files with 0 additions and 2 deletions
|
@ -3,7 +3,6 @@ import os
|
|||
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
|
||||
|
||||
os.environ["EQX_ON_ERROR"] = "nan"
|
||||
from functools import partial
|
||||
|
||||
|
@ -25,7 +24,6 @@ from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
|||
|
||||
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
|
||||
|
||||
|
||||
all_gather = partial(process_allgather, tiled=False)
|
||||
|
||||
pdims = (2, 4)
|
||||
|
|
Loading…
Add table
Reference in a new issue