mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 17:47:11 +00:00
added things
This commit is contained in:
parent
ebab27b4c3
commit
70600fb7f7
2 changed files with 8 additions and 9 deletions
|
@ -147,24 +147,24 @@ def cic_paint(pos, mesh_shape, halo_size=0):
|
||||||
|
|
||||||
# Perform halo exchange
|
# Perform halo exchange
|
||||||
# Halo exchange along x
|
# Halo exchange along x
|
||||||
left = lax.pshuffle(mesh[-halo_size:],
|
left = lax.pshuffle(mesh[-2*halo_size:],
|
||||||
perm=range(mesh_size['nx'])[::-1],
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
axis_name='x')
|
axis_name='x')
|
||||||
right = lax.pshuffle(mesh[:halo_size],
|
right = lax.pshuffle(mesh[:2*halo_size],
|
||||||
perm=range(mesh_size['nx'])[::-1],
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
axis_name='x')
|
axis_name='x')
|
||||||
mesh = mesh.at[:halo_size].add(left)
|
mesh = mesh.at[:2*halo_size].add(left)
|
||||||
mesh = mesh.at[-halo_size:].add(right)
|
mesh = mesh.at[-2*halo_size:].add(right)
|
||||||
|
|
||||||
# Halo exchange along y
|
# Halo exchange along y
|
||||||
left = lax.pshuffle(mesh[:, -halo_size:],
|
left = lax.pshuffle(mesh[:, -2*halo_size:],
|
||||||
perm=range(mesh_size['ny'])[::-1],
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
axis_name='y')
|
axis_name='y')
|
||||||
right = lax.pshuffle(mesh[:, :halo_size],
|
right = lax.pshuffle(mesh[:, :2*halo_size],
|
||||||
perm=range(mesh_size['ny'])[::-1],
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
axis_name='y')
|
axis_name='y')
|
||||||
mesh = mesh.at[:, :halo_size].add(left)
|
mesh = mesh.at[:, :2*halo_size].add(left)
|
||||||
mesh = mesh.at[:, -halo_size:].add(right)
|
mesh = mesh.at[:, -2*halo_size:].add(right)
|
||||||
|
|
||||||
# removing halo and returning mesh
|
# removing halo and returning mesh
|
||||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
|
@ -64,7 +64,6 @@ def lpt(cosmo, initial_conditions, positions, a):
|
||||||
Computes first order LPT displacement
|
Computes first order LPT displacement
|
||||||
"""
|
"""
|
||||||
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
||||||
print(initial_force.shape)
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
|
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
|
||||||
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
||||||
|
|
Loading…
Add table
Reference in a new issue