This commit is contained in:
Wassim KABALAN 2024-04-19 10:32:38 +02:00
parent 055ceedb7e
commit 179030377b
4 changed files with 63 additions and 38 deletions

View file

@ -50,9 +50,8 @@ def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, shar
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
forces.append(force)
print(f"Shape {ifft_forces.shape}")
return jnp.stack(forces)
return jnp.stack(forces , axis=-1)
@ -64,7 +63,6 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
a = jnp.atleast_1d(a)
print(f"Shape initial {initial_conditions.shape}")
@jax.jit
def compute_dx(cosmo , i_force):
@ -85,6 +83,8 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
p = compute_p(cosmo , dx)
f = compute_f(cosmo , initial_force)
return dx, p, f