mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
fix code in LPT2
This commit is contained in:
parent
cc4f310508
commit
d62c38f457
2 changed files with 25 additions and 35 deletions
|
@ -587,5 +587,5 @@ def dGf2a(cosmo, a):
|
|||
cache = cosmo._workspace['background.growth_factor']
|
||||
f2p = cache['h2'] / cache['a'] * cache['g2']
|
||||
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
||||
E = E(cosmo, a)
|
||||
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
|
||||
E_a = E(cosmo, a)
|
||||
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E_a * D2f)
|
||||
|
|
56
jaxpm/pm.py
56
jaxpm/pm.py
|
@ -1,12 +1,11 @@
|
|||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
|
||||
normal_field)
|
||||
normal_field,zeros)
|
||||
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||
growth_rate, growth_rate_second)
|
||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||
|
@ -24,23 +23,24 @@ def pm_forces(positions,
|
|||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
print(f"pm_forces particles are {positions}")
|
||||
original_shape = positions.shape
|
||||
if mesh_shape is None:
|
||||
assert (delta is not None),\
|
||||
"If mesh_shape is not provided, delta should be provided"
|
||||
mesh_shape = delta.shape
|
||||
|
||||
positions = positions.reshape((*mesh_shape, 3))
|
||||
if paint_particles:
|
||||
paint_fn = partial(cic_paint, grid_mesh=jnp.zeros(mesh_shape))
|
||||
read_fn = partial(cic_read, positions=positions)
|
||||
paint_fn = lambda x: cic_paint(
|
||||
zeros(mesh_shape,sharding), x , halo_size=halo_size, sharding=sharding)
|
||||
read_fn = lambda x: cic_read(
|
||||
x, positions, halo_size=halo_size, sharding=sharding)
|
||||
else:
|
||||
paint_fn = cic_paint_dx
|
||||
read_fn = cic_read_dx
|
||||
paint_fn = partial(cic_paint_dx,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
read_fn = partial(cic_read_dx, halo_size=halo_size, sharding=sharding)
|
||||
|
||||
if delta is None:
|
||||
field = paint_fn(positions, halo_size=halo_size, sharding=sharding)
|
||||
field = paint_fn(positions)
|
||||
delta_k = fft3d(field)
|
||||
elif jnp.isrealobj(delta):
|
||||
delta_k = fft3d(delta)
|
||||
|
@ -54,8 +54,7 @@ def pm_forces(positions,
|
|||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
|
||||
) for i in range(3)], axis=-1) # yapf: disable
|
||||
|
||||
return forces
|
||||
|
||||
|
@ -71,19 +70,7 @@ def lpt(cosmo,
|
|||
Computes first and second order LPT displacement and momentum,
|
||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||
"""
|
||||
print(f"particles are {particles}")
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
|
||||
paint_particles = True
|
||||
original_shape = particles.shape if particles is not None else (*initial_conditions.shape, 3) # yapf: disable
|
||||
if particles is None:
|
||||
paint_particles = False
|
||||
particles = autoshmap(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(),
|
||||
out_specs=spec)() # yapf: disable
|
||||
paint_particles = particles is not None
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
@ -107,7 +94,7 @@ def lpt(cosmo,
|
|||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
||||
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
||||
shear_ii = fft3d(nabla_i_nabla_i * pot_k)
|
||||
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
|
||||
delta2 += shear_ii * shear_acc
|
||||
shear_acc += shear_ii
|
||||
|
||||
|
@ -117,10 +104,10 @@ def lpt(cosmo,
|
|||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
||||
kvec, j)
|
||||
delta2 -= fft3d(nabla_i_nabla_j * pot_k)**2
|
||||
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
|
||||
|
||||
delta_k2 = fft3d(delta2)
|
||||
init_force2 = pm_forces(displacement,
|
||||
init_force2 = pm_forces(particles,
|
||||
delta=delta_k2,
|
||||
paint_particles=paint_particles,
|
||||
halo_size=halo_size,
|
||||
|
@ -134,7 +121,7 @@ def lpt(cosmo,
|
|||
p += p2
|
||||
f += f2
|
||||
|
||||
return dx.reshape(original_shape), p, f
|
||||
return dx, p, f
|
||||
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
||||
|
@ -155,17 +142,20 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
|||
return field
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
|
||||
def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
paint_particles = particles is not None
|
||||
|
||||
forces = pm_forces(
|
||||
pos, mesh_shape=mesh_shape, halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_particles=paint_particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
|
Loading…
Add table
Reference in a new issue