mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
pm ok
This commit is contained in:
parent
055ceedb7e
commit
179030377b
4 changed files with 63 additions and 38 deletions
|
@ -50,9 +50,8 @@ def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, shar
|
|||
|
||||
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
|
||||
forces.append(force)
|
||||
print(f"Shape {ifft_forces.shape}")
|
||||
|
||||
return jnp.stack(forces)
|
||||
return jnp.stack(forces , axis=-1)
|
||||
|
||||
|
||||
|
||||
|
@ -64,7 +63,6 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
|
|||
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
|
||||
a = jnp.atleast_1d(a)
|
||||
|
||||
print(f"Shape initial {initial_conditions.shape}")
|
||||
|
||||
@jax.jit
|
||||
def compute_dx(cosmo , i_force):
|
||||
|
@ -85,6 +83,8 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
|
|||
p = compute_p(cosmo , dx)
|
||||
f = compute_f(cosmo , initial_force)
|
||||
|
||||
|
||||
|
||||
return dx, p, f
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue