This commit is contained in:
Wassim Kabalan 2024-12-08 22:45:09 +01:00
parent 7823fdaf98
commit af29c4005d
7 changed files with 68 additions and 63 deletions

View file

@ -18,7 +18,7 @@ import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import (process_allgather)
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P