From a0be772f3c5603324261985ab4456a5e5c4b877d Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 9 Jul 2024 02:35:13 +0200 Subject: [PATCH] Add transpose for single GPU --- jaxpm/_src/base_ops.py | 49 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/jaxpm/_src/base_ops.py b/jaxpm/_src/base_ops.py index 1c97557..8891328 100644 --- a/jaxpm/_src/base_ops.py +++ b/jaxpm/_src/base_ops.py @@ -5,6 +5,7 @@ from inspect import signature import jax import jax.numpy as jnp +import jax_cosmo as jc import jaxdecomp import numpy as np from jax.experimental.shard_map import shard_map @@ -22,7 +23,7 @@ class FFTOperator(CustomPartionedOperator): name = 'fftn' def single_gpu_impl(x): - return jnp.fft.fftn(x) + return jnp.fft.fftn(x).transpose([1, 2, 0]) def multi_gpu_impl(x): return pfft3d(x) @@ -33,7 +34,7 @@ class IFFTOperator(CustomPartionedOperator): name = 'ifftn' def single_gpu_impl(x): - return jnp.fft.ifftn(x) + return jnp.fft.ifftn(x).transpose([2, 0, 1]) def multi_gpu_impl(x): return pifft3d(x) @@ -216,6 +217,49 @@ class GenerateParticlesOperator(CallBackOperator): return base_sharding +class InterpolateICOperator(ShardedOperator): + + name = 'interpolate_ic' + + # TODO : find a way to allow using different transfer fn + def single_gpu_impl(kfield, kk, cosmo: jc.Cosmology, box_size): + + k = jnp.logspace(-4, 2, 128) # I don't understand why 256? + + mesh_shape = kfield.shape + pk = jc.power.linear_matter_power(cosmo, k) + pk = pk * (mesh_shape[0] / box_size[0]) * ( + mesh_shape[1] / box_size[1]) * (mesh_shape[2] / box_size[2]) + print(f"kk {kk.shape}") + print(f"kk.flatten() {kk.flatten().shape}") + delta_k = kfield * jc.scipy.interpolate.interp( + kk.flatten(), k, pk**0.5).reshape(kfield.shape) + + return delta_k + + def multi_gpu_impl(kfield, + kk, + cosmo: jc.Cosmology, + box_size, + k=jnp.logspace(-4, 2, 256)): + + mesh_shape = kfield.shape + pk = jc.power.linear_matter_power(cosmo, k) + pk = pk * (mesh_shape[0] / box_size[0]) * ( + mesh_shape[1] / box_size[1]) * (mesh_shape[2] / box_size[2]) + delta_k = kfield * jc.scipy.interpolate.interp( + kk.flatten(), k, pk**0.5).reshape(kfield.shape) + + return delta_k + + def infer_sharding_from_base_sharding(base_sharding): + + in_spec = base_sharding, base_sharding, P(), P(), P() + out_spec = base_sharding + + return in_spec, out_spec + + register_operator(FFTOperator) register_operator(IFFTOperator) register_operator(HaloExchangeOperator) @@ -224,3 +268,4 @@ register_operator(UnpaddingOperator) register_operator(NormalFieldOperator) register_operator(FFTKOperator) register_operator(GenerateParticlesOperator) +register_operator(InterpolateICOperator)