mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
Add transpose for single GPU
This commit is contained in:
parent
d2fb1ee1e2
commit
a0be772f3c
1 changed files with 47 additions and 2 deletions
|
@ -5,6 +5,7 @@ from inspect import signature
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import jax_cosmo as jc
|
||||||
import jaxdecomp
|
import jaxdecomp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.experimental.shard_map import shard_map
|
from jax.experimental.shard_map import shard_map
|
||||||
|
@ -22,7 +23,7 @@ class FFTOperator(CustomPartionedOperator):
|
||||||
name = 'fftn'
|
name = 'fftn'
|
||||||
|
|
||||||
def single_gpu_impl(x):
|
def single_gpu_impl(x):
|
||||||
return jnp.fft.fftn(x)
|
return jnp.fft.fftn(x).transpose([1, 2, 0])
|
||||||
|
|
||||||
def multi_gpu_impl(x):
|
def multi_gpu_impl(x):
|
||||||
return pfft3d(x)
|
return pfft3d(x)
|
||||||
|
@ -33,7 +34,7 @@ class IFFTOperator(CustomPartionedOperator):
|
||||||
name = 'ifftn'
|
name = 'ifftn'
|
||||||
|
|
||||||
def single_gpu_impl(x):
|
def single_gpu_impl(x):
|
||||||
return jnp.fft.ifftn(x)
|
return jnp.fft.ifftn(x).transpose([2, 0, 1])
|
||||||
|
|
||||||
def multi_gpu_impl(x):
|
def multi_gpu_impl(x):
|
||||||
return pifft3d(x)
|
return pifft3d(x)
|
||||||
|
@ -216,6 +217,49 @@ class GenerateParticlesOperator(CallBackOperator):
|
||||||
return base_sharding
|
return base_sharding
|
||||||
|
|
||||||
|
|
||||||
|
class InterpolateICOperator(ShardedOperator):
|
||||||
|
|
||||||
|
name = 'interpolate_ic'
|
||||||
|
|
||||||
|
# TODO : find a way to allow using different transfer fn
|
||||||
|
def single_gpu_impl(kfield, kk, cosmo: jc.Cosmology, box_size):
|
||||||
|
|
||||||
|
k = jnp.logspace(-4, 2, 128) # I don't understand why 256?
|
||||||
|
|
||||||
|
mesh_shape = kfield.shape
|
||||||
|
pk = jc.power.linear_matter_power(cosmo, k)
|
||||||
|
pk = pk * (mesh_shape[0] / box_size[0]) * (
|
||||||
|
mesh_shape[1] / box_size[1]) * (mesh_shape[2] / box_size[2])
|
||||||
|
print(f"kk {kk.shape}")
|
||||||
|
print(f"kk.flatten() {kk.flatten().shape}")
|
||||||
|
delta_k = kfield * jc.scipy.interpolate.interp(
|
||||||
|
kk.flatten(), k, pk**0.5).reshape(kfield.shape)
|
||||||
|
|
||||||
|
return delta_k
|
||||||
|
|
||||||
|
def multi_gpu_impl(kfield,
|
||||||
|
kk,
|
||||||
|
cosmo: jc.Cosmology,
|
||||||
|
box_size,
|
||||||
|
k=jnp.logspace(-4, 2, 256)):
|
||||||
|
|
||||||
|
mesh_shape = kfield.shape
|
||||||
|
pk = jc.power.linear_matter_power(cosmo, k)
|
||||||
|
pk = pk * (mesh_shape[0] / box_size[0]) * (
|
||||||
|
mesh_shape[1] / box_size[1]) * (mesh_shape[2] / box_size[2])
|
||||||
|
delta_k = kfield * jc.scipy.interpolate.interp(
|
||||||
|
kk.flatten(), k, pk**0.5).reshape(kfield.shape)
|
||||||
|
|
||||||
|
return delta_k
|
||||||
|
|
||||||
|
def infer_sharding_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
in_spec = base_sharding, base_sharding, P(), P(), P()
|
||||||
|
out_spec = base_sharding
|
||||||
|
|
||||||
|
return in_spec, out_spec
|
||||||
|
|
||||||
|
|
||||||
register_operator(FFTOperator)
|
register_operator(FFTOperator)
|
||||||
register_operator(IFFTOperator)
|
register_operator(IFFTOperator)
|
||||||
register_operator(HaloExchangeOperator)
|
register_operator(HaloExchangeOperator)
|
||||||
|
@ -224,3 +268,4 @@ register_operator(UnpaddingOperator)
|
||||||
register_operator(NormalFieldOperator)
|
register_operator(NormalFieldOperator)
|
||||||
register_operator(FFTKOperator)
|
register_operator(FFTKOperator)
|
||||||
register_operator(GenerateParticlesOperator)
|
register_operator(GenerateParticlesOperator)
|
||||||
|
register_operator(InterpolateICOperator)
|
||||||
|
|
Loading…
Add table
Reference in a new issue