mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
kernels.py no longer uses global mesh
This commit is contained in:
parent
591ee32c55
commit
a5b267bd63
1 changed files with 31 additions and 29 deletions
|
@ -3,7 +3,6 @@ from enum import Enum
|
|||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax.lib.xla_client import FftType
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jaxdecomp import fftfreq3d, get_output_specs
|
||||
|
@ -11,7 +10,6 @@ from jaxdecomp import fftfreq3d, get_output_specs
|
|||
from jaxpm.distributed import autoshmap
|
||||
|
||||
|
||||
|
||||
def fftk(k_array):
|
||||
"""
|
||||
Generate Fourier transform wave numbers for a given mesh.
|
||||
|
@ -28,31 +26,35 @@ def fftk(k_array):
|
|||
return kx, ky, kz
|
||||
|
||||
|
||||
def interpolate_power_spectrum(input, k, pk):
|
||||
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||
).reshape(x.shape)
|
||||
specs = P('x', 'y')
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
|
||||
|
||||
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input)
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
specs = sharding.spec if sharding is not None else P()
|
||||
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
|
||||
|
||||
return autoshmap(pk_fn,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=out_specs,
|
||||
out_specs=out_specs)(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
"""
|
||||
Computes the gradient kernel in the requested direction
|
||||
Parameters:
|
||||
-----------
|
||||
kvec: array
|
||||
Array of k values in Fourier space
|
||||
direction: int
|
||||
Index of the direction in which to take the gradient
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
Computes the gradient kernel in the requested direction
|
||||
Parameters:
|
||||
-----------
|
||||
kvec: array
|
||||
Array of k values in Fourier space
|
||||
direction: int
|
||||
Index of the direction in which to take the gradient
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
if order == 0:
|
||||
wts = 1j * kvec[direction]
|
||||
wts = jnp.squeeze(wts)
|
||||
|
@ -68,16 +70,16 @@ def gradient_kernel(kvec, direction, order=1):
|
|||
|
||||
def laplace_kernel(kvec):
|
||||
"""
|
||||
Compute the Laplace kernel from a given K vector
|
||||
Parameters:
|
||||
-----------
|
||||
kvec: array
|
||||
Array of k values in Fourier space
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
Compute the Laplace kernel from a given K vector
|
||||
Parameters:
|
||||
-----------
|
||||
kvec: array
|
||||
Array of k values in Fourier space
|
||||
Returns:
|
||||
--------
|
||||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
wts = jnp.where(kk == 0, 1., 1. / kk)
|
||||
return wts
|
||||
|
|
Loading…
Add table
Reference in a new issue