mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
update PGDCorrection and neural ode to use new fft3d
This commit is contained in:
parent
505f2ec286
commit
5d4f438e92
1 changed files with 11 additions and 16 deletions
25
jaxpm/pm.py
25
jaxpm/pm.py
|
@ -44,9 +44,7 @@ def pm_forces(positions,
|
||||||
forces = jnp.stack([
|
forces = jnp.stack([
|
||||||
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding) for i in range(3)
|
sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
|
||||||
],
|
|
||||||
axis=-1)
|
|
||||||
|
|
||||||
return forces
|
return forces
|
||||||
|
|
||||||
|
@ -58,8 +56,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
||||||
"""
|
"""
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
spec = sharding.spec if sharding is not None else P()
|
spec = sharding.spec if sharding is not None else P()
|
||||||
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding),
|
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
|
||||||
3)
|
|
||||||
displacement = autoshmap(
|
displacement = autoshmap(
|
||||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||||
gpu_mesh=gpu_mesh,
|
gpu_mesh=gpu_mesh,
|
||||||
|
@ -88,7 +85,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
||||||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||||
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
|
||||||
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
|
||||||
shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
|
shear_ii = fft3d(nabla_i_nabla_i * pot_k)
|
||||||
delta2 += shear_ii * shear_acc
|
delta2 += shear_ii * shear_acc
|
||||||
shear_acc += shear_ii
|
shear_acc += shear_ii
|
||||||
|
|
||||||
|
@ -98,7 +95,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
||||||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
||||||
kvec, j)
|
kvec, j)
|
||||||
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
|
delta2 -= fft3d(nabla_i_nabla_j * pot_k)**2
|
||||||
|
|
||||||
delta_k2 = fft3d(delta2)
|
delta_k2 = fft3d(delta2)
|
||||||
init_force2 = pm_forces(displacement,
|
init_force2 = pm_forces(displacement,
|
||||||
|
@ -191,16 +188,16 @@ def pgd_correction(pos, mesh_shape, params):
|
||||||
pos: particle positions [npart, 3]
|
pos: particle positions [npart, 3]
|
||||||
params: [alpha, kl, ks] pgd parameters
|
params: [alpha, kl, ks] pgd parameters
|
||||||
"""
|
"""
|
||||||
kvec = fftk(mesh_shape)
|
|
||||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
|
delta_k = fft3d(delta)
|
||||||
|
kvec = fftk(delta_k)
|
||||||
alpha, kl, ks = params
|
alpha, kl, ks = params
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
|
||||||
PGD_range = PGD_kernel(kvec, kl, ks)
|
PGD_range = PGD_kernel(kvec, kl, ks)
|
||||||
|
|
||||||
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
|
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
|
||||||
|
|
||||||
forces_pgd = jnp.stack([
|
forces_pgd = jnp.stack([
|
||||||
cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
],
|
],
|
||||||
axis=-1)
|
axis=-1)
|
||||||
|
@ -217,11 +214,9 @@ def make_neural_ode_fn(model, mesh_shape):
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
kvec = fftk(mesh_shape)
|
|
||||||
|
|
||||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
|
delta_k = fft3d(delta)
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
kvec = fftk(delta_k)
|
||||||
|
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
|
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
|
||||||
|
@ -233,7 +228,7 @@ def make_neural_ode_fn(model, mesh_shape):
|
||||||
|
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([
|
forces = jnp.stack([
|
||||||
cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k), pos)
|
cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos)
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
],
|
],
|
||||||
axis=-1)
|
axis=-1)
|
||||||
|
|
Loading…
Add table
Reference in a new issue