Merge branch 'main' into neural_ode

This commit is contained in:
Francois Lanusse 2024-07-19 10:48:09 -04:00 committed by GitHub
commit 9a279d2d6c
16 changed files with 858 additions and 328 deletions

View file

View file

@ -0,0 +1,292 @@
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
axis_resources = {'x': 'nx', 'y': 'ny'}
mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1)
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
return a * factor
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def add(a, b):
return a + b
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
Args:
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
Returns:
3D FFT of the input, note that the dimensions of the output
are tranposed.
"""
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y]
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources)
def fn(key):
""" Generate a distributed random normal distributions
Args:
key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample
"""
return jax.random.normal(
key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap,
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap,
in_axes=([...]),
out_axes={
0: 'x',
2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources)
def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(pos):
mesh = jnp.zeros([
mesh_shape[0] // mesh_size['nx'] +
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles
mesh = paint.cic_paint(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
# Halo exchange along x
left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y
left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=(
{
0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(mesh, pos):
# Halo exchange to grab neighboring borders
# Exchange along x
left = lax.pshuffle(mesh[-halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = jnp.concatenate([left, mesh, right], axis=0)
# Exchange along y
left = lax.pshuffle(mesh[:, -halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = paint.cic_read(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1])
return fn(mesh, pos)
@partial(pjit,
in_axis_resources=PartitionSpec('nx', 'ny'),
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x):
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([
mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit,
in_axis_resources=PartitionSpec('nx', None, 'ny', None),
out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x):
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -0,0 +1,100 @@
from functools import partial
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
"""
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
mesh_shape = delta_k.shape
kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)]
if delta_k is None:
delta = dops.cic_paint(positions, mesh_shape, halo_size)
delta_k = dops.fft3d(dops.reshape_split_to_dense(delta))
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions
forces = [
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces)
def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
"""
Generate initial conditions.
Seed should have the dimension of the computational mesh
"""
# Sample normal field
field = dops.normal(seed, shape=mesh_shape)
# Go to Fourier space
field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units
kvec = [
k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] *
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
if return_Fourier:
return field
else:
return dops.reshape_dense_to_split(dops.ifft3d(field))
def lpt(cosmo, initial_conditions, positions, a):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(
dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = dops.scalar_multiply(
vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
# Computes the update of velocity (kick)
dvel = dops.scalar_multiply(
forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
return dpos, dvel
return nbody_ode

View file

@ -1,8 +1,8 @@
import jax.numpy as np
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint
from jax_cosmo.background import *
def E(cosmo, a):
r"""Scale factor dependent factor E(a) in the Hubble
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
"""
return (
3
* cosmo.wa
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
/ np.power(np.log(a - epsilon), 2)
)
return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
np.power(np.log(a - epsilon), 2))
def dEa(cosmo, a):
@ -89,15 +85,11 @@ def dEa(cosmo, a):
where :math:`f(a)` is the Dark Energy evolution parameter computed
by :py:meth:`.f_de`.
"""
return (
0.5
* (
-3 * cosmo.Omega_m * np.power(a, -4)
- 2 * cosmo.Omega_k * np.power(a, -3)
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
)
/ np.power(Esqr(cosmo, a), 0.5)
)
return (0.5 *
(-3 * cosmo.Omega_m * np.power(a, -4) -
2 * cosmo.Omega_k * np.power(a, -3) +
df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
np.power(Esqr(cosmo, a), 0.5))
def growth_factor(cosmo, a):
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth rate is not implemented for second order growth!"
)
"Gamma growth rate is not implemented for second order growth!")
return None
else:
return _growth_factor_second_ODE(cosmo, a)
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth factor is not implemented for second order growth!"
)
"Gamma growth factor is not implemented for second order growth!")
return None
else:
return _growth_rate_second_ODE(cosmo, a)
@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
atab = np.logspace(log10_amin, 0.0, steps)
def D_derivs(y, x):
q = (
2.0
- 0.5
* (
Omega_m_a(cosmo, x)
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
)
) / x
q = (2.0 - 0.5 *
(Omega_m_a(cosmo, x) +
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x
g1, g2 = y[0]
f1, f2 = y[1]
dy1da = [f1, -q * f1 + r * g1]
dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2]
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab)
# compute second order derivatives growth
@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
see :cite:`2019:Euclid Preparation VII, eqn.32`
"""
return Omega_m_a(cosmo, a) ** cosmo.gamma
return Omega_m_a(cosmo, a)**cosmo.gamma
def Gf(cosmo, a):
@ -503,7 +488,7 @@ def Gf(cosmo, a):
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a
D1f = f1 * g1 / a
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -532,7 +517,7 @@ def Gf2(cosmo, a):
"""
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a
D2f = f2 * g2 / a
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -563,13 +548,12 @@ def dGfa(cosmo, a):
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a
D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor']
f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) +
3 * a**2 * Ea * D1f)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
def dGf2a(cosmo, a):
@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
"""
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a
D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor']
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E * D2f)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)

View file

@ -1,25 +1,27 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size
""" Return k_vector given a shape (nc, nc, nc) and box_size
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
k.append(kd.astype(dtype))
del kd, kdshape
return k
k.append(kd.astype(dtype))
del kd, kdshape
return k
def gradient_kernel(kvec, direction, order=1):
"""
"""
Computes the gradient kernel in the requested direction
Parameters:
-----------
@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1):
wts: array
Complex kernel
"""
if order == 0:
wts = 1j * kvec[direction]
wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0
wts = wts.reshape(kvec[direction].shape)
return wts
else:
w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j
return wts
if order == 0:
wts = 1j * kvec[direction]
wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0
wts = wts.reshape(kvec[direction].shape)
return wts
else:
w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j
return wts
def laplace_kernel(kvec):
"""
"""
Compute the Laplace kernel from a given K vector
Parameters:
-----------
@ -56,16 +59,17 @@ def laplace_kernel(kvec):
wts: array
Complex kernel
"""
kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero()
kk[mask] = 1
wts = 1. / kk
imask = (~(kk == 0)).astype(int)
wts *= imask
return wts
kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero()
kk[mask] = 1
wts = 1. / kk
imask = (~(kk == 0)).astype(int)
wts *= imask
return wts
def longrange_kernel(kvec, r_split):
"""
"""
Computes a long range kernel
Parameters:
-----------
@ -78,29 +82,31 @@ def longrange_kernel(kvec, r_split):
wts: array
kernel
"""
if r_split != 0:
kk = sum(ki**2 for ki in kvec)
return np.exp(-kk * r_split**2)
else:
return 1.
if r_split != 0:
kk = sum(ki**2 for ki in kvec)
return np.exp(-kk * r_split**2)
else:
return 1.
def cic_compensation(kvec):
"""
"""
Computes cic compensation kernel.
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
Itself based on equation 18 (with p=2) of
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
Args:
kvec: array of k values in Fourier space
kvec: array of k values in Fourier space
Returns:
v: array of kernel
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
def PGD_kernel(kvec, kl, ks):
"""
"""
Computes the PGD kernel
Parameters:
-----------
@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks):
v: array
kernel
"""
kk = sum(ki**2 for ki in kvec)
kl2 = kl**2
ks4 = ks**4
mask = (kk == 0).nonzero()
kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int)
v *= imask
return v
kk = sum(ki**2 for ki in kvec)
kl2 = kl**2
ks4 = ks**4
mask = (kk == 0).nonzero()
kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int)
v *= imask
return v

View file

@ -1,11 +1,12 @@
import jax
import jax
import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d
from jaxpm.utils import gaussian_smoothing
def density_plane(positions,
box_shape,
@ -26,9 +27,11 @@ def density_plane(positions,
xy = xy / nx * plane_resolution
# Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
@ -36,45 +39,44 @@ def density_plane(positions,
# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane,
smoothing_sigma)
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
return density_plane
def convergence_Born(cosmo,
density_planes,
coords,
z_source):
"""
def convergence_Born(cosmo, density_planes, coords, z_source):
"""
Compute the Born convergence
Args:
cosmo: `Cosmology`, cosmology object.
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation.
Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values.
"""
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))
convergence = 0
for entry in density_planes:
r = entry['r']; a = entry['a']; p = entry['plane']
dx = entry['dx']; dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
convergence = 0
for entry in density_planes:
r = entry['r']
a = entry['a']
p = entry['plane']
dx = entry['dx']
dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates
im = map_coordinates(p,
coords * r / dx - 0.5,
order=1, mode="wrap")
# Interpolate at the density plane coordinates
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence
return convergence

View file

@ -1,6 +1,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
import haiku as hk
def _deBoorVectorized(x, t, c, p):
"""
@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
c: array of control points
p: degree of B-spline
"""
k = jnp.digitize(x, t) -1
d = [c[j + k - p] for j in range(0, p+1)]
for r in range(1, p+1):
for j in range(p, r-1, -1):
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
k = jnp.digitize(x, t) - 1
d = [c[j + k - p] for j in range(0, p + 1)]
for r in range(1, p + 1):
for j in range(p, r - 1, -1):
alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
return d[p]
class NeuralSplineFourierFilter(hk.Module):
"""A rotationally invariant filter parameterized by
"""A rotationally invariant filter parameterized by
a b-spline with parameters specified by a small NN."""
def __init__(self, n_knots=8, latent_size=16, name=None):
def __init__(self, n_knots=8, latent_size=16, name=None):
"""
n_knots: number of control points for the spline
"""
n_knots: number of control points for the spline
"""
super().__init__(name=name)
self.n_knots = n_knots
self.latent_size = latent_size
super().__init__(name=name)
self.n_knots = n_knots
self.latent_size = latent_size
def __call__(self, x, a):
"""
def __call__(self, x, a):
"""
x: array, scale, normalized to fftfreq default
a: scalar, scale factor
"""
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
net = jnp.sin(hk.Linear(self.latent_size)(net))
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
net = jnp.sin(hk.Linear(self.latent_size)(net))
w = hk.Linear(self.n_knots+1)(net)
k = hk.Linear(self.n_knots-1)(net)
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1,)),
jnp.cumsum(jax.nn.softmax(k))])
w = hk.Linear(self.n_knots + 1)(net)
k = hk.Linear(self.n_knots - 1)(net)
w = jnp.concatenate([jnp.zeros((1,)),
w])
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
# Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
w = jnp.concatenate([jnp.zeros((1, )), w])
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
# Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
3)

View file

@ -1,96 +1,100 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import fftk, cic_compensation
from jaxpm.kernels import cic_compensation, fftk
def cic_paint(mesh, positions):
""" Paints positions onto mesh
def cic_paint(mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,8]),
dnums)
return mesh
def cic_read(mesh, positions):
""" Paints positions onto mesh
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
return (mesh[neighboor_coords[...,0],
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[...,jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0,
1))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,4]),
dnums)
return mesh
def compensate_cic(field):
"""
"""
Compensate for CiC painting
Args:
field: input 3D cic-painted field
Returns:
compensated_field
"""
nc = field.shape
kvec = fftk(nc)
nc = field.shape
kvec = fftk(nc)
delta_k = jnp.fft.rfftn(field)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)
delta_k = jnp.fft.rfftn(field)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel)
from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
"""
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
# Computes gravitational forces
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions)
for i in range(3)],axis=-1)
return jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
],
axis=-1)
def lpt(cosmo, initial_conditions, positions, a):
@ -34,25 +39,31 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta=initial_conditions)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
a) * initial_force
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
return field
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
@ -63,10 +74,10 @@ def make_ode_fn(mesh_shape):
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
return nbody_ode
@ -128,4 +139,3 @@ def make_neural_ode_fn(model, mesh_shape):
return dpos, dvel
return neural_nbody_ode

View file

@ -1,85 +1,87 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
__all__ = ['power_spectrum']
def _initialize_pk(shape, boxsize, kmin, dk):
"""
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
"""
I = np.eye(len(shape), dtype='int') * -2 + 1
I = np.eye(len(shape), dtype='int') * -2 + 1
W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0
W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)
k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5
k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)
xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)
dig = np.digitize(kmag.flat, kedges)
dig = np.digitize(kmag.flat, kedges)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
"""
"""
Calculate the powerspectra given real space field
Args:
field: real valued field
field: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
Returns:
kbins: the central value of the bins for plotting
power: real valued array of power in each bin
"""
shape = field.shape
nx, ny, nz = shape
shape = field.shape
nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform
fft_image = jnp.fft.fftn(field)
#fast fourier transform
fft_image = jnp.fft.fftn(field)
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
return kbins, P / norm
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
"""
@ -131,18 +133,17 @@ def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=Fals
def gaussian_smoothing(im, sigma):
"""
"""
im: 2d image
sigma: smoothing scale in px
sigma: smoothing scale in px
"""
# Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])),
axis=-1)
k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0,0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
# Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])),
axis=-1)
k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0, 0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real