mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-16 16:10:54 +00:00
kernels.py no longer uses global mesh
This commit is contained in:
parent
591ee32c55
commit
a5b267bd63
1 changed files with 31 additions and 29 deletions
|
@ -3,7 +3,6 @@ from enum import Enum
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax._src import mesh as mesh_lib
|
|
||||||
from jax.lib.xla_client import FftType
|
from jax.lib.xla_client import FftType
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from jaxdecomp import fftfreq3d, get_output_specs
|
from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
|
@ -11,7 +10,6 @@ from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def fftk(k_array):
|
def fftk(k_array):
|
||||||
"""
|
"""
|
||||||
Generate Fourier transform wave numbers for a given mesh.
|
Generate Fourier transform wave numbers for a given mesh.
|
||||||
|
@ -28,31 +26,35 @@ def fftk(k_array):
|
||||||
return kx, ky, kz
|
return kx, ky, kz
|
||||||
|
|
||||||
|
|
||||||
def interpolate_power_spectrum(input, k, pk):
|
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
|
|
||||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||||
).reshape(x.shape)
|
).reshape(x.shape)
|
||||||
specs = P('x', 'y')
|
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
|
||||||
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
|
|
||||||
|
|
||||||
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input)
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
specs = sharding.spec if sharding is not None else P()
|
||||||
|
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
|
||||||
|
|
||||||
|
return autoshmap(pk_fn,
|
||||||
|
gpu_mesh=gpu_mesh,
|
||||||
|
in_specs=out_specs,
|
||||||
|
out_specs=out_specs)(input)
|
||||||
|
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
"""
|
"""
|
||||||
Computes the gradient kernel in the requested direction
|
Computes the gradient kernel in the requested direction
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
kvec: array
|
kvec: array
|
||||||
Array of k values in Fourier space
|
Array of k values in Fourier space
|
||||||
direction: int
|
direction: int
|
||||||
Index of the direction in which to take the gradient
|
Index of the direction in which to take the gradient
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel
|
Complex kernel
|
||||||
"""
|
"""
|
||||||
if order == 0:
|
if order == 0:
|
||||||
wts = 1j * kvec[direction]
|
wts = 1j * kvec[direction]
|
||||||
wts = jnp.squeeze(wts)
|
wts = jnp.squeeze(wts)
|
||||||
|
@ -68,16 +70,16 @@ def gradient_kernel(kvec, direction, order=1):
|
||||||
|
|
||||||
def laplace_kernel(kvec):
|
def laplace_kernel(kvec):
|
||||||
"""
|
"""
|
||||||
Compute the Laplace kernel from a given K vector
|
Compute the Laplace kernel from a given K vector
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
kvec: array
|
kvec: array
|
||||||
Array of k values in Fourier space
|
Array of k values in Fourier space
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel
|
Complex kernel
|
||||||
"""
|
"""
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
wts = jnp.where(kk == 0, 1., 1. / kk)
|
wts = jnp.where(kk == 0, 1., 1. / kk)
|
||||||
return wts
|
return wts
|
||||||
|
|
Loading…
Add table
Reference in a new issue