mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
apply formatting
This commit is contained in:
parent
11f7e90066
commit
4342279817
5 changed files with 26 additions and 22 deletions
|
@ -40,7 +40,7 @@ def ifft3d(x):
|
|||
def get_halo_size(halo_size, sharding):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is None or gpu_mesh.empty:
|
||||
zero_ext = (0, 0, 0)
|
||||
zero_ext = (0, 0)
|
||||
zero_tuple = (0, 0)
|
||||
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
|
||||
else:
|
||||
|
|
|
@ -5,8 +5,8 @@ import jax.lax as lax
|
|||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
|
||||
slice_pad, slice_unpad, fft3d, ifft3d)
|
||||
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
|
||||
ifft3d, slice_pad, slice_unpad)
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue