Merge branch 'main' into neural_ode

This commit is contained in:
Francois Lanusse 2024-07-19 10:48:09 -04:00 committed by GitHub
commit 9a279d2d6c
16 changed files with 858 additions and 328 deletions

17
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)

14
dev/job_pfft.sh Normal file
View file

@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -A m1727
#SBATCH -C gpu
#SBATCH -q debug
#SBATCH -t 0:05:00
#SBATCH -N 2
#SBATCH --ntasks-per-node=4
#SBATCH -c 32
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
export SLURM_CPU_BIND="cores"
srun python test_pfft.py

96
dev/test_pfft.py Normal file
View file

@ -0,0 +1,96 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))
with Mesh(devices, ('nx', 'ny')):
mesh = pnormal(key)
kmesh = pfft3d(mesh)
kmesh.block_until_ready()
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
print('Done')

View file

@ -1,17 +1,21 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import matplotlib.pylab as plt
import jax
import numpy as np
import jax.numpy as jnp
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfp = tfp.substrates.jax
tfd = tfp.distributions
def cic_paint(mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
@ -19,26 +23,27 @@ def cic_paint(mesh, positions):
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1,8]),
dnums)
kernel.reshape([-1, 8]), dnums)
return mesh
# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos),

View file

View file

@ -0,0 +1,292 @@
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
axis_resources = {'x': 'nx', 'y': 'ny'}
mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1)
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
return a * factor
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def add(a, b):
return a + b
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
Args:
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
Returns:
3D FFT of the input, note that the dimensions of the output
are tranposed.
"""
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y]
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources)
def fn(key):
""" Generate a distributed random normal distributions
Args:
key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample
"""
return jax.random.normal(
key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap,
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap,
in_axes=([...]),
out_axes={
0: 'x',
2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources)
def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(pos):
mesh = jnp.zeros([
mesh_shape[0] // mesh_size['nx'] +
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles
mesh = paint.cic_paint(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
# Halo exchange along x
left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y
left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=(
{
0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(mesh, pos):
# Halo exchange to grab neighboring borders
# Exchange along x
left = lax.pshuffle(mesh[-halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = jnp.concatenate([left, mesh, right], axis=0)
# Exchange along y
left = lax.pshuffle(mesh[:, -halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = paint.cic_read(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1])
return fn(mesh, pos)
@partial(pjit,
in_axis_resources=PartitionSpec('nx', 'ny'),
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x):
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([
mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit,
in_axis_resources=PartitionSpec('nx', None, 'ny', None),
out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x):
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -0,0 +1,100 @@
from functools import partial
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
"""
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
mesh_shape = delta_k.shape
kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)]
if delta_k is None:
delta = dops.cic_paint(positions, mesh_shape, halo_size)
delta_k = dops.fft3d(dops.reshape_split_to_dense(delta))
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions
forces = [
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces)
def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
"""
Generate initial conditions.
Seed should have the dimension of the computational mesh
"""
# Sample normal field
field = dops.normal(seed, shape=mesh_shape)
# Go to Fourier space
field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units
kvec = [
k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] *
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
if return_Fourier:
return field
else:
return dops.reshape_dense_to_split(dops.ifft3d(field))
def lpt(cosmo, initial_conditions, positions, a):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(
dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = dops.scalar_multiply(
vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
# Computes the update of velocity (kick)
dvel = dops.scalar_multiply(
forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
return dpos, dvel
return nbody_ode

View file

@ -1,8 +1,8 @@
import jax.numpy as np
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint
from jax_cosmo.background import *
def E(cosmo, a):
r"""Scale factor dependent factor E(a) in the Hubble
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
"""
return (
3
* cosmo.wa
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
/ np.power(np.log(a - epsilon), 2)
)
return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
np.power(np.log(a - epsilon), 2))
def dEa(cosmo, a):
@ -89,15 +85,11 @@ def dEa(cosmo, a):
where :math:`f(a)` is the Dark Energy evolution parameter computed
by :py:meth:`.f_de`.
"""
return (
0.5
* (
-3 * cosmo.Omega_m * np.power(a, -4)
- 2 * cosmo.Omega_k * np.power(a, -3)
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
)
/ np.power(Esqr(cosmo, a), 0.5)
)
return (0.5 *
(-3 * cosmo.Omega_m * np.power(a, -4) -
2 * cosmo.Omega_k * np.power(a, -3) +
df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
np.power(Esqr(cosmo, a), 0.5))
def growth_factor(cosmo, a):
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth rate is not implemented for second order growth!"
)
"Gamma growth rate is not implemented for second order growth!")
return None
else:
return _growth_factor_second_ODE(cosmo, a)
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth factor is not implemented for second order growth!"
)
"Gamma growth factor is not implemented for second order growth!")
return None
else:
return _growth_rate_second_ODE(cosmo, a)
@ -258,14 +248,9 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
atab = np.logspace(log10_amin, 0.0, steps)
def D_derivs(y, x):
q = (
2.0
- 0.5
* (
Omega_m_a(cosmo, x)
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
)
) / x
q = (2.0 - 0.5 *
(Omega_m_a(cosmo, x) +
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x
g1, g2 = y[0]
@ -274,7 +259,8 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab)
# compute second order derivatives growth
@ -476,7 +462,6 @@ def _growth_rate_gamma(cosmo, a):
return Omega_m_a(cosmo, a)**cosmo.gamma
def Gf(cosmo, a):
r"""
FastPM growth factor function
@ -568,8 +553,7 @@ def dGfa(cosmo, a):
f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) +
3 * a**2 * Ea * D1f)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
def dGf2a(cosmo, a):
@ -604,5 +588,4 @@ def dGf2a(cosmo, a):
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E * D2f)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)

View file

@ -1,5 +1,6 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size
@ -18,6 +19,7 @@ def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
del kd, kdshape
return k
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction
@ -44,6 +46,7 @@ def gradient_kernel(kvec, direction, order=1):
wts = a * 1j
return wts
def laplace_kernel(kvec):
"""
Compute the Laplace kernel from a given K vector
@ -64,6 +67,7 @@ def laplace_kernel(kvec):
wts *= imask
return wts
def longrange_kernel(kvec, r_split):
"""
Computes a long range kernel
@ -84,6 +88,7 @@ def longrange_kernel(kvec, r_split):
else:
return 1.
def cic_compensation(kvec):
"""
Computes cic compensation kernel.
@ -99,6 +104,7 @@ def cic_compensation(kvec):
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
def PGD_kernel(kvec, kl, ks):
"""
Computes the PGD kernel

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d
from jaxpm.utils import gaussian_smoothing
def density_plane(positions,
box_shape,
@ -26,9 +27,11 @@ def density_plane(positions,
xy = xy / nx * plane_resolution
# Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
@ -36,16 +39,12 @@ def density_plane(positions,
# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane,
smoothing_sigma)
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
return density_plane
def convergence_Born(cosmo,
density_planes,
coords,
z_source):
def convergence_Born(cosmo, density_planes, coords, z_source):
"""
Compute the Born convergence
Args:
@ -60,21 +59,24 @@ def convergence_Born(cosmo,
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))
convergence = 0
for entry in density_planes:
r = entry['r']; a = entry['a']; p = entry['plane']
dx = entry['dx']; dz = entry['dz']
r = entry['r']
a = entry['a']
p = entry['plane']
dx = entry['dx']
dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates
im = map_coordinates(p,
coords * r / dx - 0.5,
order=1, mode="wrap")
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence

View file

@ -1,6 +1,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
import haiku as hk
def _deBoorVectorized(x, t, c, p):
"""
@ -48,13 +49,12 @@ class NeuralSplineFourierFilter(hk.Module):
k = hk.Linear(self.n_knots - 1)(net)
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1,)),
jnp.cumsum(jax.nn.softmax(k))])
k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
w = jnp.concatenate([jnp.zeros((1,)),
w])
w = jnp.concatenate([jnp.zeros((1, )), w])
# Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
3)

View file

@ -1,36 +1,39 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import fftk, cic_compensation
from jaxpm.kernels import cic_compensation, fftk
def cic_paint(mesh, positions):
def cic_paint(mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,8]),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
return mesh
def cic_read(mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
@ -38,20 +41,20 @@ def cic_read(mesh, positions):
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
return (mesh[neighboor_coords[...,0],
neighboor_coords[...,1],
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
@ -68,18 +71,19 @@ def cic_paint_2d(mesh, positions, weight):
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,4]),
scatter_dims_to_operand_dims=(0,
1))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
dnums)
return mesh
def compensate_cic(field):
"""
Compensate for CiC painting

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel)
from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
"""
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
# Computes gravitational forces
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions)
for i in range(3)],axis=-1)
return jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
],
axis=-1)
def lpt(cosmo, initial_conditions, positions, a):
@ -34,23 +39,29 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta=initial_conditions)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
a) * initial_force
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
return field
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
@ -128,4 +139,3 @@ def make_neural_ode_fn(model, mesh_shape):
return dpos, dvel
return neural_nbody_ode

View file

@ -1,9 +1,10 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
__all__ = ['power_spectrum']
def _initialize_pk(shape, boxsize, kmin, dk):
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
@ -63,12 +64,12 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
@ -81,6 +82,7 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
return kbins, P / norm
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
"""
Calculate the cross correlation coefficients given two real space field
@ -145,4 +147,3 @@ def gaussian_smoothing(im, sigma):
filter /= filter[0, 0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real

View file

@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import find_packages, setup
setup(
name='JaxPM',