mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-16 04:41:11 +00:00
Applying formatting
This commit is contained in:
parent
5f463450d1
commit
a2811c0606
15 changed files with 566 additions and 446 deletions
|
@ -1,13 +1,14 @@
|
|||
import jax
|
||||
from jax.lax import linear_solve_p
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.maps import xmap
|
||||
from functools import partial
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.kernels import fftk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.lax import linear_solve_p
|
||||
|
||||
import jaxpm.experimental.distributed_ops as dops
|
||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||
from jaxpm.kernels import fftk
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
||||
|
@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
|||
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
|
||||
|
||||
# Recovers forces at particle positions
|
||||
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
|
||||
positions, halo_size) for f in forces_k]
|
||||
forces = [
|
||||
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
|
||||
halo_size) for f in forces_k
|
||||
]
|
||||
|
||||
return dops.stack3d(*forces)
|
||||
|
||||
|
@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
|
|||
field = dops.fft3d(dops.reshape_split_to_dense(field))
|
||||
|
||||
# Rescaling k to physical units
|
||||
kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
|
||||
kvec = [
|
||||
k.squeeze() / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
|
||||
]
|
||||
k = jnp.logspace(-4, 2, 256)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
|
||||
) / (box_size[0] * box_size[1] * box_size[2])
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] *
|
||||
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
|
||||
|
||||
|
@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
|
|||
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
|
||||
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
||||
jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||
p = dops.scalar_multiply(
|
||||
dx,
|
||||
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||
return dx, p
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue