mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
format
This commit is contained in:
parent
86233081e2
commit
82b8f563a0
4 changed files with 60 additions and 44 deletions
|
@ -84,8 +84,8 @@ def invlaplace_kernel(kvec):
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
kk_nozeros = jnp.where(kk==0, 1, kk)
|
kk_nozeros = jnp.where(kk == 0, 1, kk)
|
||||||
return - jnp.where(kk==0, 0, 1 / kk_nozeros)
|
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
|
||||||
|
|
||||||
|
|
||||||
def longrange_kernel(kvec, r_split):
|
def longrange_kernel(kvec, r_split):
|
||||||
|
|
66
jaxpm/pm.py
66
jaxpm/pm.py
|
@ -9,8 +9,8 @@ from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
|
||||||
normal_field)
|
normal_field)
|
||||||
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||||
growth_rate, growth_rate_second)
|
growth_rate, growth_rate_second)
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, invlaplace_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||||
longrange_kernel)
|
invlaplace_kernel, longrange_kernel)
|
||||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,11 +38,11 @@ def pm_forces(positions,
|
||||||
|
|
||||||
kvec = fftk(delta_k)
|
kvec = fftk(delta_k)
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
|
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
||||||
r_split=r_split)
|
kvec, r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([
|
forces = jnp.stack([
|
||||||
cic_read_dx(ifft3d( - gradient_kernel(kvec, i) * pot_k),
|
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding) for i in range(3)
|
sharding=sharding) for i in range(3)
|
||||||
],
|
],
|
||||||
|
@ -51,7 +51,7 @@ def pm_forces(positions,
|
||||||
return forces
|
return forces
|
||||||
|
|
||||||
|
|
||||||
def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None,order=1):
|
def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
||||||
"""
|
"""
|
||||||
Computes first and second order LPT displacement and momentum,
|
Computes first and second order LPT displacement and momentum,
|
||||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||||
|
@ -76,7 +76,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None,order=1):
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
dx = growth_factor(cosmo, a) * initial_force
|
dx = growth_factor(cosmo, a) * initial_force
|
||||||
p = a**2 * growth_rate(cosmo, a) * E * dx
|
p = a**2 * growth_rate(cosmo, a) * E * dx
|
||||||
f = a**2 * E * dGfa(cosmo,a) * initial_force
|
f = a**2 * E * dGfa(cosmo, a) * initial_force
|
||||||
if order == 2:
|
if order == 2:
|
||||||
kvec = fftk(delta_k)
|
kvec = fftk(delta_k)
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec)
|
pot_k = delta_k * invlaplace_kernel(kvec)
|
||||||
|
@ -93,22 +93,26 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None,order=1):
|
||||||
shear_acc += shear_ii
|
shear_acc += shear_ii
|
||||||
|
|
||||||
# for kj in kvec[i+1:]:
|
# for kj in kvec[i+1:]:
|
||||||
for j in range(i+1, 3):
|
for j in range(i + 1, 3):
|
||||||
# Substract squared strict-up-triangle terms
|
# Substract squared strict-up-triangle terms
|
||||||
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
|
||||||
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j)
|
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
|
||||||
|
kvec, j)
|
||||||
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
|
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
|
||||||
|
|
||||||
delta_k2 = fft3d(delta2)
|
delta_k2 = fft3d(delta2)
|
||||||
init_force2 = pm_forces(displacement, delta=delta_k2,halo_size=halo_size,sharding=sharding)
|
init_force2 = pm_forces(displacement,
|
||||||
|
delta=delta_k2,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||||
dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2
|
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
|
||||||
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
|
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
|
||||||
f2 = a**2 * E * dGf2a(cosmo, a) * init_force2
|
f2 = a**2 * E * dGf2a(cosmo, a) * init_force2
|
||||||
|
|
||||||
dx += dx2
|
dx += dx2
|
||||||
p += p2
|
p += p2
|
||||||
f += f2
|
f += f2
|
||||||
|
|
||||||
return dx, p, f
|
return dx, p, f
|
||||||
|
|
||||||
|
@ -153,6 +157,7 @@ def make_ode_fn(mesh_shape, halo_size=0, sharding=None):
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
||||||
|
|
||||||
def nbody_ode(a, state, args):
|
def nbody_ode(a, state, args):
|
||||||
|
@ -162,7 +167,9 @@ def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
||||||
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
forces = pm_forces(pos, mesh_shape, halo_size=halo_size, sharding=sharding) * 1.5 * cosmo.Omega_m
|
forces = pm_forces(
|
||||||
|
pos, mesh_shape, halo_size=halo_size,
|
||||||
|
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
@ -188,20 +195,24 @@ def pgd_correction(pos, mesh_shape, params):
|
||||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
alpha, kl, ks = params
|
alpha, kl, ks = params
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
delta_k = jnp.fft.rfftn(delta)
|
||||||
PGD_range=PGD_kernel(kvec, kl, ks)
|
PGD_range = PGD_kernel(kvec, kl, ks)
|
||||||
|
|
||||||
pot_k_pgd=(delta_k * invlaplace_kernel(kvec))*PGD_range
|
pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
|
||||||
|
|
||||||
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
forces_pgd = jnp.stack([
|
||||||
for i in range(3)],axis=-1)
|
cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||||
|
for i in range(3)
|
||||||
|
],
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
dpos_pgd = forces_pgd*alpha
|
dpos_pgd = forces_pgd * alpha
|
||||||
|
|
||||||
return dpos_pgd
|
return dpos_pgd
|
||||||
|
|
||||||
|
|
||||||
def make_neural_ode_fn(model, mesh_shape):
|
def make_neural_ode_fn(model, mesh_shape):
|
||||||
def neural_nbody_ode(state, a, cosmo:Cosmology, params):
|
|
||||||
|
def neural_nbody_ode(state, a, cosmo: Cosmology, params):
|
||||||
"""
|
"""
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
"""
|
"""
|
||||||
|
@ -213,15 +224,19 @@ def make_neural_ode_fn(model, mesh_shape):
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
delta_k = jnp.fft.rfftn(delta)
|
||||||
|
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
|
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
|
||||||
|
r_split=0)
|
||||||
|
|
||||||
# Apply a correction filter
|
# Apply a correction filter
|
||||||
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
|
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
|
||||||
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
||||||
|
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos)
|
forces = jnp.stack([
|
||||||
for i in range(3)],axis=-1)
|
cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k), pos)
|
||||||
|
for i in range(3)
|
||||||
|
],
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
forces = forces * 1.5 * cosmo.Omega_m
|
forces = forces * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
@ -232,4 +247,5 @@ def make_neural_ode_fn(model, mesh_shape):
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
return neural_nbody_ode
|
return neural_nbody_ode
|
||||||
|
|
Loading…
Add table
Reference in a new issue