mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-19 01:20:55 +00:00
kernels.py no longer uses global mesh
This commit is contained in:
parent
591ee32c55
commit
a5b267bd63
1 changed files with 31 additions and 29 deletions
|
@ -3,7 +3,6 @@ from enum import Enum
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax._src import mesh as mesh_lib
|
|
||||||
from jax.lib.xla_client import FftType
|
from jax.lib.xla_client import FftType
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from jaxdecomp import fftfreq3d, get_output_specs
|
from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
|
@ -11,7 +10,6 @@ from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def fftk(k_array):
|
def fftk(k_array):
|
||||||
"""
|
"""
|
||||||
Generate Fourier transform wave numbers for a given mesh.
|
Generate Fourier transform wave numbers for a given mesh.
|
||||||
|
@ -28,15 +26,19 @@ def fftk(k_array):
|
||||||
return kx, ky, kz
|
return kx, ky, kz
|
||||||
|
|
||||||
|
|
||||||
def interpolate_power_spectrum(input, k, pk):
|
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
|
|
||||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||||
).reshape(x.shape)
|
).reshape(x.shape)
|
||||||
specs = P('x', 'y')
|
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
|
||||||
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
|
|
||||||
|
|
||||||
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input)
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
specs = sharding.spec if sharding is not None else P()
|
||||||
|
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
|
||||||
|
|
||||||
|
return autoshmap(pk_fn,
|
||||||
|
gpu_mesh=gpu_mesh,
|
||||||
|
in_specs=out_specs,
|
||||||
|
out_specs=out_specs)(input)
|
||||||
|
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
|
|
Loading…
Add table
Reference in a new issue