From 5d4f438e92a148a223f180216ace1672a4792267 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Fri, 25 Oct 2024 10:14:29 +0200 Subject: [PATCH] update PGDCorrection and neural ode to use new fft3d --- jaxpm/pm.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 3155467..321bf0c 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -43,10 +43,8 @@ def pm_forces(positions, # Computes gravitational forces forces = jnp.stack([ cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k), - halo_size=halo_size, - sharding=sharding) for i in range(3) - ], - axis=-1) + halo_size=halo_size, + sharding=sharding) for i in range(3)], axis=-1) # yapf: disable return forces @@ -58,8 +56,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1): """ gpu_mesh = sharding.mesh if sharding is not None else None spec = sharding.spec if sharding is not None else P() - local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), - 3) + local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable displacement = autoshmap( partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), gpu_mesh=gpu_mesh, @@ -88,7 +85,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1): # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) nabla_i_nabla_i = gradient_kernel(kvec, i)**2 - shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k) + shear_ii = fft3d(nabla_i_nabla_i * pot_k) delta2 += shear_ii * shear_acc shear_acc += shear_ii @@ -98,7 +95,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1): # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel( kvec, j) - delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2 + delta2 -= fft3d(nabla_i_nabla_j * pot_k)**2 delta_k2 = fft3d(delta2) init_force2 = pm_forces(displacement, @@ -191,16 +188,16 @@ def pgd_correction(pos, mesh_shape, params): pos: particle positions [npart, 3] params: [alpha, kl, ks] pgd parameters """ - kvec = fftk(mesh_shape) delta = cic_paint(jnp.zeros(mesh_shape), pos) + delta_k = fft3d(delta) + kvec = fftk(delta_k) alpha, kl, ks = params - delta_k = jnp.fft.rfftn(delta) PGD_range = PGD_kernel(kvec, kl, ks) pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range forces_pgd = jnp.stack([ - cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k_pgd), pos) + cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos) for i in range(3) ], axis=-1) @@ -217,11 +214,9 @@ def make_neural_ode_fn(model, mesh_shape): state is a tuple (position, velocities) """ pos, vel = state - kvec = fftk(mesh_shape) - delta = cic_paint(jnp.zeros(mesh_shape), pos) - - delta_k = jnp.fft.rfftn(delta) + delta_k = fft3d(delta) + kvec = fftk(delta_k) # Computes gravitational potential pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, @@ -233,7 +228,7 @@ def make_neural_ode_fn(model, mesh_shape): # Computes gravitational forces forces = jnp.stack([ - cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k), pos) + cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos) for i in range(3) ], axis=-1)