From d62c38f457b1b5c230502c671ee4e62ded829f18 Mon Sep 17 00:00:00 2001
From: Wassim KABALAN <wastondev@gmail.com>
Date: Sun, 27 Oct 2024 03:48:38 +0100
Subject: [PATCH] fix code in LPT2

---
 jaxpm/growth.py |  4 ++--
 jaxpm/pm.py     | 56 ++++++++++++++++++++-----------------------------
 2 files changed, 25 insertions(+), 35 deletions(-)

diff --git a/jaxpm/growth.py b/jaxpm/growth.py
index 5b6908c..8194b06 100644
--- a/jaxpm/growth.py
+++ b/jaxpm/growth.py
@@ -587,5 +587,5 @@ def dGf2a(cosmo, a):
     cache = cosmo._workspace['background.growth_factor']
     f2p = cache['h2'] / cache['a'] * cache['g2']
     f2p = interp(np.log(a), np.log(cache['a']), f2p)
-    E = E(cosmo, a)
-    return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
+    E_a = E(cosmo, a)
+    return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E_a * D2f)
diff --git a/jaxpm/pm.py b/jaxpm/pm.py
index e9bfaef..b41f261 100644
--- a/jaxpm/pm.py
+++ b/jaxpm/pm.py
@@ -1,12 +1,11 @@
 from functools import partial
 
-import jax
 import jax.numpy as jnp
 import jax_cosmo as jc
 from jax.sharding import PartitionSpec as P
 
 from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
-                               normal_field)
+                               normal_field,zeros)
 from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
                           growth_rate, growth_rate_second)
 from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
@@ -24,23 +23,24 @@ def pm_forces(positions,
     """
     Computes gravitational forces on particles using a PM scheme
     """
-    print(f"pm_forces particles are {positions}")
-    original_shape = positions.shape
     if mesh_shape is None:
         assert (delta is not None),\
           "If mesh_shape is not provided, delta should be provided"
         mesh_shape = delta.shape
 
-    positions = positions.reshape((*mesh_shape, 3))
     if paint_particles:
-        paint_fn = partial(cic_paint, grid_mesh=jnp.zeros(mesh_shape))
-        read_fn = partial(cic_read, positions=positions)
+        paint_fn = lambda x: cic_paint(
+            zeros(mesh_shape,sharding), x , halo_size=halo_size, sharding=sharding)
+        read_fn = lambda x: cic_read(
+            x, positions, halo_size=halo_size, sharding=sharding)
     else:
-        paint_fn = cic_paint_dx
-        read_fn = cic_read_dx
+        paint_fn = partial(cic_paint_dx,
+                           halo_size=halo_size,
+                           sharding=sharding)
+        read_fn = partial(cic_read_dx, halo_size=halo_size, sharding=sharding)
 
     if delta is None:
-        field = paint_fn(positions, halo_size=halo_size, sharding=sharding)
+        field = paint_fn(positions)
         delta_k = fft3d(field)
     elif jnp.isrealobj(delta):
         delta_k = fft3d(delta)
@@ -54,8 +54,7 @@ def pm_forces(positions,
     # Computes gravitational forces
     forces = jnp.stack([
         read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
-        halo_size=halo_size,
-        sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
+        ) for i in range(3)], axis=-1) # yapf: disable
 
     return forces
 
@@ -71,19 +70,7 @@ def lpt(cosmo,
     Computes first and second order LPT displacement and momentum,
     e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
     """
-    print(f"particles are {particles}")
-    gpu_mesh = sharding.mesh if sharding is not None else None
-    spec = sharding.spec if sharding is not None else P()
-    local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
-    paint_particles = True
-    original_shape = particles.shape if particles is not None else (*initial_conditions.shape, 3) # yapf: disable
-    if particles is None:
-        paint_particles = False
-        particles = autoshmap(
-          partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
-          gpu_mesh=gpu_mesh,
-          in_specs=(),
-          out_specs=spec)()  # yapf: disable
+    paint_particles = particles is not None
 
     a = jnp.atleast_1d(a)
     E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@@ -107,7 +94,7 @@ def lpt(cosmo,
             # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
             # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
             nabla_i_nabla_i = gradient_kernel(kvec, i)**2
-            shear_ii = fft3d(nabla_i_nabla_i * pot_k)
+            shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
             delta2 += shear_ii * shear_acc
             shear_acc += shear_ii
 
@@ -117,10 +104,10 @@ def lpt(cosmo,
                 # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
                 nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
                     kvec, j)
-                delta2 -= fft3d(nabla_i_nabla_j * pot_k)**2
+                delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
 
         delta_k2 = fft3d(delta2)
-        init_force2 = pm_forces(displacement,
+        init_force2 = pm_forces(particles,
                                 delta=delta_k2,
                                 paint_particles=paint_particles,
                                 halo_size=halo_size,
@@ -134,7 +121,7 @@ def lpt(cosmo,
         p += p2
         f += f2
 
-    return dx.reshape(original_shape), p, f
+    return dx, p, f
 
 
 def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
@@ -155,17 +142,20 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
     return field
 
 
-def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
+def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None):
 
     def nbody_ode(state, a, cosmo):
         """
         state is a tuple (position, velocities)
         """
         pos, vel = state
+        paint_particles = particles is not None
 
-        forces = pm_forces(
-            pos, mesh_shape=mesh_shape, halo_size=halo_size,
-            sharding=sharding) * 1.5 * cosmo.Omega_m
+        forces = pm_forces(pos,
+                           mesh_shape=mesh_shape,
+                           paint_particles=paint_particles,
+                           halo_size=halo_size,
+                           sharding=sharding) * 1.5 * cosmo.Omega_m
 
         # Computes the update of position (drift)
         dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel