update PGDCorrection and neural ode to use new fft3d

This commit is contained in:
Wassim KABALAN 2024-10-25 10:14:29 +02:00
parent 505f2ec286
commit 5d4f438e92

View file

@ -43,10 +43,8 @@ def pm_forces(positions,
# Computes gravitational forces
forces = jnp.stack([
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
halo_size=halo_size,
sharding=sharding) for i in range(3)
],
axis=-1)
halo_size=halo_size,
sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
return forces
@ -58,8 +56,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
"""
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding),
3)
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
displacement = autoshmap(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
gpu_mesh=gpu_mesh,
@ -88,7 +85,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
shear_ii = fft3d(nabla_i_nabla_i * pot_k)
delta2 += shear_ii * shear_acc
shear_acc += shear_ii
@ -98,7 +95,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
kvec, j)
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
delta2 -= fft3d(nabla_i_nabla_j * pot_k)**2
delta_k2 = fft3d(delta2)
init_force2 = pm_forces(displacement,
@ -191,16 +188,16 @@ def pgd_correction(pos, mesh_shape, params):
pos: particle positions [npart, 3]
params: [alpha, kl, ks] pgd parameters
"""
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
delta_k = fft3d(delta)
kvec = fftk(delta_k)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range = PGD_kernel(kvec, kl, ks)
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
forces_pgd = jnp.stack([
cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
],
axis=-1)
@ -217,11 +214,9 @@ def make_neural_ode_fn(model, mesh_shape):
state is a tuple (position, velocities)
"""
pos, vel = state
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
delta_k = jnp.fft.rfftn(delta)
delta_k = fft3d(delta)
kvec = fftk(delta_k)
# Computes gravitational potential
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
@ -233,7 +228,7 @@ def make_neural_ode_fn(model, mesh_shape):
# Computes gravitational forces
forces = jnp.stack([
cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k), pos)
cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos)
for i in range(3)
],
axis=-1)