Merge branch 'main' into neural_ode

This commit is contained in:
Francois Lanusse 2024-07-19 10:48:09 -04:00 committed by GitHub
commit 9a279d2d6c
16 changed files with 858 additions and 328 deletions

17
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)

View file

@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
## Objective ## Objective
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
## Related Work ## Related Work
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models. This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow - [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD - [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- Borg - Borg

14
dev/job_pfft.sh Normal file
View file

@ -0,0 +1,14 @@
#!/bin/bash
#SBATCH -A m1727
#SBATCH -C gpu
#SBATCH -q debug
#SBATCH -t 0:05:00
#SBATCH -N 2
#SBATCH --ntasks-per-node=4
#SBATCH -c 32
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
export SLURM_CPU_BIND="cores"
srun python test_pfft.py

96
dev/test_pfft.py Normal file
View file

@ -0,0 +1,96 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))
with Mesh(devices, ('nx', 'ny')):
mesh = pnormal(key)
kmesh = pfft3d(mesh)
kmesh.block_until_ready()
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
print('Done')

View file

@ -1,48 +1,53 @@
# Start this script with: # Start this script with:
# mpirun -np 4 python test_script.py # mpirun -np 4 python test_script.py
import os import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import matplotlib.pylab as plt import jax
import jax
import numpy as np
import jax.numpy as jnp
import jax.lax as lax import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit from jax.experimental.pjit import PartitionSpec, pjit
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfp = tfp.substrates.jax
tfd = tfp.distributions tfd = tfp.distributions
def cic_paint(mesh, positions): def cic_paint(mesh, positions):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1, 8]), dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords.reshape([-1,8,3]).astype('int32'),
kernel.reshape([-1,8]),
dnums)
return mesh
# And let's draw some points from some 3D distribution # And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.) dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0)) pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos), f = pjit(lambda x: cic_paint(x, pos),
in_axis_resources=PartitionSpec('x', 'y', 'z'), in_axis_resources=PartitionSpec('x', 'y', 'z'),
out_axis_resources=None) out_axis_resources=None)
devices = np.array(jax.devices()).reshape((2, 2, 1)) devices = np.array(jax.devices()).reshape((2, 2, 1))
@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
m = jnp.zeros([32, 32, 32]) m = jnp.zeros([32, 32, 32])
with mesh(devices, ('x', 'y', 'z')): with mesh(devices, ('x', 'y', 'z')):
# Shard the mesh, I'm not sure this is absolutely necessary # Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x, m = pjit(lambda x: x,
in_axis_resources=None, in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m) out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Apply the sharded CiC function # Apply the sharded CiC function
res = f(m) res = f(m)
plt.imshow(res.sum(axis=2)) plt.imshow(res.sum(axis=2))
plt.show() plt.show()

View file

View file

@ -0,0 +1,292 @@
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
axis_resources = {'x': 'nx', 'y': 'ny'}
mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1)
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
return a * factor
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def add(a, b):
return a + b
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
Args:
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
Returns:
3D FFT of the input, note that the dimensions of the output
are tranposed.
"""
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y]
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources)
def fn(key):
""" Generate a distributed random normal distributions
Args:
key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample
"""
return jax.random.normal(
key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap,
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap,
in_axes=([...]),
out_axes={
0: 'x',
2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources)
def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(pos):
mesh = jnp.zeros([
mesh_shape[0] // mesh_size['nx'] +
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles
mesh = paint.cic_paint(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
# Halo exchange along x
left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y
left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=(
{
0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(mesh, pos):
# Halo exchange to grab neighboring borders
# Exchange along x
left = lax.pshuffle(mesh[-halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = jnp.concatenate([left, mesh, right], axis=0)
# Exchange along y
left = lax.pshuffle(mesh[:, -halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = paint.cic_read(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1])
return fn(mesh, pos)
@partial(pjit,
in_axis_resources=PartitionSpec('nx', 'ny'),
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x):
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([
mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit,
in_axis_resources=PartitionSpec('nx', None, 'ny', None),
out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x):
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -0,0 +1,100 @@
from functools import partial
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
"""
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
mesh_shape = delta_k.shape
kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)]
if delta_k is None:
delta = dops.cic_paint(positions, mesh_shape, halo_size)
delta_k = dops.fft3d(dops.reshape_split_to_dense(delta))
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions
forces = [
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces)
def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
"""
Generate initial conditions.
Seed should have the dimension of the computational mesh
"""
# Sample normal field
field = dops.normal(seed, shape=mesh_shape)
# Go to Fourier space
field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units
kvec = [
k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] *
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
if return_Fourier:
return field
else:
return dops.reshape_dense_to_split(dops.ifft3d(field))
def lpt(cosmo, initial_conditions, positions, a):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(
dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = dops.scalar_multiply(
vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
# Computes the update of velocity (kick)
dvel = dops.scalar_multiply(
forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
return dpos, dvel
return nbody_ode

View file

@ -1,8 +1,8 @@
import jax.numpy as np import jax.numpy as np
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint from jax_cosmo.scipy.ode import odeint
from jax_cosmo.background import *
def E(cosmo, a): def E(cosmo, a):
r"""Scale factor dependent factor E(a) in the Hubble r"""Scale factor dependent factor E(a) in the Hubble
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)- \frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)} \frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
""" """
return ( return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
3 np.power(np.log(a - epsilon), 2))
* cosmo.wa
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
/ np.power(np.log(a - epsilon), 2)
)
def dEa(cosmo, a): def dEa(cosmo, a):
@ -89,15 +85,11 @@ def dEa(cosmo, a):
where :math:`f(a)` is the Dark Energy evolution parameter computed where :math:`f(a)` is the Dark Energy evolution parameter computed
by :py:meth:`.f_de`. by :py:meth:`.f_de`.
""" """
return ( return (0.5 *
0.5 (-3 * cosmo.Omega_m * np.power(a, -4) -
* ( 2 * cosmo.Omega_k * np.power(a, -3) +
-3 * cosmo.Omega_m * np.power(a, -4) df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
- 2 * cosmo.Omega_k * np.power(a, -3) np.power(Esqr(cosmo, a), 0.5))
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
)
/ np.power(Esqr(cosmo, a), 0.5)
)
def growth_factor(cosmo, a): def growth_factor(cosmo, a):
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
""" """
if cosmo._flags["gamma_growth"]: if cosmo._flags["gamma_growth"]:
raise NotImplementedError( raise NotImplementedError(
"Gamma growth rate is not implemented for second order growth!" "Gamma growth rate is not implemented for second order growth!")
)
return None return None
else: else:
return _growth_factor_second_ODE(cosmo, a) return _growth_factor_second_ODE(cosmo, a)
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
""" """
if cosmo._flags["gamma_growth"]: if cosmo._flags["gamma_growth"]:
raise NotImplementedError( raise NotImplementedError(
"Gamma growth factor is not implemented for second order growth!" "Gamma growth factor is not implemented for second order growth!")
)
return None return None
else: else:
return _growth_rate_second_ODE(cosmo, a) return _growth_rate_second_ODE(cosmo, a)
@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
atab = np.logspace(log10_amin, 0.0, steps) atab = np.logspace(log10_amin, 0.0, steps)
def D_derivs(y, x): def D_derivs(y, x):
q = ( q = (2.0 - 0.5 *
2.0 (Omega_m_a(cosmo, x) +
- 0.5 (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
* (
Omega_m_a(cosmo, x)
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
)
) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x r = 1.5 * Omega_m_a(cosmo, x) / x / x
g1, g2 = y[0] g1, g2 = y[0]
f1, f2 = y[1] f1, f2 = y[1]
dy1da = [f1, -q * f1 + r * g1] dy1da = [f1, -q * f1 + r * g1]
dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2] dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]]) return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]]) y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab) y = odeint(D_derivs, y0, atab)
# compute second order derivatives growth # compute second order derivatives growth
@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
see :cite:`2019:Euclid Preparation VII, eqn.32` see :cite:`2019:Euclid Preparation VII, eqn.32`
""" """
return Omega_m_a(cosmo, a) ** cosmo.gamma return Omega_m_a(cosmo, a)**cosmo.gamma
def Gf(cosmo, a): def Gf(cosmo, a):
@ -503,7 +488,7 @@ def Gf(cosmo, a):
""" """
f1 = growth_rate(cosmo, a) f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a) g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a D1f = f1 * g1 / a
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -532,7 +517,7 @@ def Gf2(cosmo, a):
""" """
f2 = growth_rate_second(cosmo, a) f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a) g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a D2f = f2 * g2 / a
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -563,13 +548,12 @@ def dGfa(cosmo, a):
""" """
f1 = growth_rate(cosmo, a) f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a) g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor'] cache = cosmo._workspace['background.growth_factor']
f1p = cache['h'] / cache['a'] * cache['g'] f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p) f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a) Ea = E(cosmo, a)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
3 * a**2 * Ea * D1f)
def dGf2a(cosmo, a): def dGf2a(cosmo, a):
@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
""" """
f2 = growth_rate_second(cosmo, a) f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a) g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor'] cache = cosmo._workspace['background.growth_factor']
f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p) f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a) E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
3 * a**2 * E * D2f)

View file

@ -1,25 +1,27 @@
import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32): def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size """ Return k_vector given a shape (nc, nc, nc) and box_size
""" """
k = [] k = []
for d in range(len(shape)): for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d]) kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int') kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1: if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1] kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd) kdshape[d] = len(kd)
kd = kd.reshape(kdshape) kd = kd.reshape(kdshape)
k.append(kd.astype(dtype))
del kd, kdshape
return k
k.append(kd.astype(dtype))
del kd, kdshape
return k
def gradient_kernel(kvec, direction, order=1): def gradient_kernel(kvec, direction, order=1):
""" """
Computes the gradient kernel in the requested direction Computes the gradient kernel in the requested direction
Parameters: Parameters:
----------- -----------
@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1):
wts: array wts: array
Complex kernel Complex kernel
""" """
if order == 0: if order == 0:
wts = 1j * kvec[direction] wts = 1j * kvec[direction]
wts = jnp.squeeze(wts) wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0 wts[len(wts) // 2] = 0
wts = wts.reshape(kvec[direction].shape) wts = wts.reshape(kvec[direction].shape)
return wts return wts
else: else:
w = kvec[direction] w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j wts = a * 1j
return wts return wts
def laplace_kernel(kvec): def laplace_kernel(kvec):
""" """
Compute the Laplace kernel from a given K vector Compute the Laplace kernel from a given K vector
Parameters: Parameters:
----------- -----------
@ -56,16 +59,17 @@ def laplace_kernel(kvec):
wts: array wts: array
Complex kernel Complex kernel
""" """
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero() mask = (kk == 0).nonzero()
kk[mask] = 1 kk[mask] = 1
wts = 1. / kk wts = 1. / kk
imask = (~(kk == 0)).astype(int) imask = (~(kk == 0)).astype(int)
wts *= imask wts *= imask
return wts return wts
def longrange_kernel(kvec, r_split): def longrange_kernel(kvec, r_split):
""" """
Computes a long range kernel Computes a long range kernel
Parameters: Parameters:
----------- -----------
@ -78,29 +82,31 @@ def longrange_kernel(kvec, r_split):
wts: array wts: array
kernel kernel
""" """
if r_split != 0: if r_split != 0:
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
return np.exp(-kk * r_split**2) return np.exp(-kk * r_split**2)
else: else:
return 1. return 1.
def cic_compensation(kvec): def cic_compensation(kvec):
""" """
Computes cic compensation kernel. Computes cic compensation kernel.
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
Itself based on equation 18 (with p=2) of Itself based on equation 18 (with p=2) of
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_ `Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
Args: Args:
kvec: array of k values in Fourier space kvec: array of k values in Fourier space
Returns: Returns:
v: array of kernel v: array of kernel
""" """
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2) wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts return wts
def PGD_kernel(kvec, kl, ks): def PGD_kernel(kvec, kl, ks):
""" """
Computes the PGD kernel Computes the PGD kernel
Parameters: Parameters:
----------- -----------
@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks):
v: array v: array
kernel kernel
""" """
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
kl2 = kl**2 kl2 = kl**2
ks4 = ks**4 ks4 = ks**4
mask = (kk == 0).nonzero() mask = (kk == 0).nonzero()
kk[mask] = 1 kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int) imask = (~(kk == 0)).astype(int)
v *= imask v *= imask
return v return v

View file

@ -1,11 +1,12 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo import jax_cosmo
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d from jaxpm.painting import cic_paint_2d
from jaxpm.utils import gaussian_smoothing
def density_plane(positions, def density_plane(positions,
box_shape, box_shape,
@ -26,9 +27,11 @@ def density_plane(positions,
xy = xy / nx * plane_resolution xy = xy / nx * plane_resolution
# Selecting only particles that fall inside the volume of interest # Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.) weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane # Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight) density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
# Apply density normalization # Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) * density_plane = density_plane / ((nx / plane_resolution) *
@ -36,45 +39,44 @@ def density_plane(positions,
# Apply Gaussian smoothing if requested # Apply Gaussian smoothing if requested
if smoothing_sigma is not None: if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane, density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
smoothing_sigma)
return density_plane return density_plane
def convergence_Born(cosmo, def convergence_Born(cosmo, density_planes, coords, z_source):
density_planes, """
coords,
z_source):
"""
Compute the Born convergence Compute the Born convergence
Args: Args:
cosmo: `Cosmology`, cosmology object. cosmo: `Cosmology`, cosmology object.
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2]. coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] . z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation. name: `string`, name of the operation.
Returns: Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values. `Tensor` of shape [batch_size, N, Nz], of convergence values.
""" """
# Compute constant prefactor: # Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2 constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies # Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source)) r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))
convergence = 0 convergence = 0
for entry in density_planes: for entry in density_planes:
r = entry['r']; a = entry['a']; p = entry['plane'] r = entry['r']
dx = entry['dx']; dz = entry['dz'] a = entry['a']
# Normalize density planes p = entry['plane']
density_normalization = dz * r / a dx = entry['dx']
p = (p - p.mean()) * constant_factor * density_normalization dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates # Interpolate at the density plane coordinates
im = map_coordinates(p, im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
coords * r / dx - 0.5,
order=1, mode="wrap")
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1]) convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence return convergence

View file

@ -1,6 +1,7 @@
import haiku as hk
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import haiku as hk
def _deBoorVectorized(x, t, c, p): def _deBoorVectorized(x, t, c, p):
""" """
@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
c: array of control points c: array of control points
p: degree of B-spline p: degree of B-spline
""" """
k = jnp.digitize(x, t) -1 k = jnp.digitize(x, t) - 1
d = [c[j + k - p] for j in range(0, p+1)] d = [c[j + k - p] for j in range(0, p + 1)]
for r in range(1, p+1): for r in range(1, p + 1):
for j in range(p, r-1, -1): for j in range(p, r - 1, -1):
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p]) alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j] d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
return d[p] return d[p]
class NeuralSplineFourierFilter(hk.Module): class NeuralSplineFourierFilter(hk.Module):
"""A rotationally invariant filter parameterized by """A rotationally invariant filter parameterized by
a b-spline with parameters specified by a small NN.""" a b-spline with parameters specified by a small NN."""
def __init__(self, n_knots=8, latent_size=16, name=None): def __init__(self, n_knots=8, latent_size=16, name=None):
"""
n_knots: number of control points for the spline
""" """
n_knots: number of control points for the spline super().__init__(name=name)
""" self.n_knots = n_knots
super().__init__(name=name) self.latent_size = latent_size
self.n_knots = n_knots
self.latent_size = latent_size
def __call__(self, x, a): def __call__(self, x, a):
""" """
x: array, scale, normalized to fftfreq default x: array, scale, normalized to fftfreq default
a: scalar, scale factor a: scalar, scale factor
""" """
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a))) net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
net = jnp.sin(hk.Linear(self.latent_size)(net)) net = jnp.sin(hk.Linear(self.latent_size)(net))
w = hk.Linear(self.n_knots+1)(net) w = hk.Linear(self.n_knots + 1)(net)
k = hk.Linear(self.n_knots-1)(net) k = hk.Linear(self.n_knots - 1)(net)
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1,)),
jnp.cumsum(jax.nn.softmax(k))])
w = jnp.concatenate([jnp.zeros((1,)), # make sure the knots sum to 1 and are in the interval 0,1
w]) k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
# Augment with repeating points w = jnp.concatenate([jnp.zeros((1, )), w])
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) # Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
3)

View file

@ -1,96 +1,100 @@
import jax import jax
import jax.numpy as jnp
import jax.lax as lax import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import fftk, cic_compensation from jaxpm.kernels import cic_compensation, fftk
def cic_paint(mesh, positions):
""" Paints positions onto mesh def cic_paint(mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,8]),
dnums)
return mesh
def cic_read(mesh, positions): def cic_read(mesh, positions):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape)) neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
return (mesh[neighboor_coords[...,0],
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh """ Paints positions onto a 2d mesh
mesh: [nx, ny] mesh: [nx, ny]
positions: [npart, 2] positions: [npart, 2]
weight: [npart] weight: [npart]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None: if weight is not None:
kernel = kernel * weight[...,jnp.newaxis] kernel = kernel * weight[..., jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0,
1))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,4]),
dnums)
return mesh
def compensate_cic(field): def compensate_cic(field):
""" """
Compensate for CiC painting Compensate for CiC painting
Args: Args:
field: input 3D cic-painted field field: input 3D cic-painted field
Returns: Returns:
compensated_field compensated_field
""" """
nc = field.shape nc = field.shape
kvec = fftk(nc) kvec = fftk(nc)
delta_k = jnp.fft.rfftn(field) delta_k = jnp.fft.rfftn(field)
delta_k = cic_compensation(kvec) * delta_k delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k) return jnp.fft.irfftn(delta_k)

View file

@ -1,11 +1,12 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel)
from jaxpm.painting import cic_paint, cic_read from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
""" """
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
delta_k = jnp.fft.rfftn(delta) delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential # Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions) return jnp.stack([
for i in range(3)],axis=-1) cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
],
axis=-1)
def lpt(cosmo, initial_conditions, positions, a): def lpt(cosmo, initial_conditions, positions, a):
@ -34,25 +39,31 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta=initial_conditions) initial_force = pm_forces(positions, delta=initial_conditions)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
a) * initial_force
return dx, p, f return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed): def linear_field(mesh_shape, box_size, pk, seed):
""" """
Generate initial conditions. Generate initial conditions.
""" """
kvec = fftk(mesh_shape) kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape) field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5 field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field) field = jnp.fft.irfftn(field)
return field return field
def make_ode_fn(mesh_shape): def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo): def nbody_ode(state, a, cosmo):
""" """
state is a tuple (position, velocities) state is a tuple (position, velocities)
@ -63,10 +74,10 @@ def make_ode_fn(mesh_shape):
# Computes the update of position (drift) # Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel return dpos, dvel
return nbody_ode return nbody_ode
@ -128,4 +139,3 @@ def make_neural_ode_fn(model, mesh_shape):
return dpos, dvel return dpos, dvel
return neural_nbody_ode return neural_nbody_ode

View file

@ -1,85 +1,87 @@
import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm from jax.scipy.stats import norm
__all__ = ['power_spectrum'] __all__ = ['power_spectrum']
def _initialize_pk(shape, boxsize, kmin, dk): def _initialize_pk(shape, boxsize, kmin, dk):
""" """
Helper function to initialize various (fixed) values for powerspectra... not differentiable! Helper function to initialize various (fixed) values for powerspectra... not differentiable!
""" """
I = np.eye(len(shape), dtype='int') * -2 + 1 I = np.eye(len(shape), dtype='int') * -2 + 1
W = np.empty(shape, dtype='f4') W = np.empty(shape, dtype='f4')
W[...] = 2.0 W[...] = 2.0
W[..., 0] = 1.0 W[..., 0] = 1.0
W[..., -1] = 1.0 W[..., -1] = 1.0
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk) kedges = np.arange(kmin, kmax, dk)
k = [ k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape) for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
] ]
kmag = sum(ki**2 for ki in k)**0.5 kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1) xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1) Nsum = np.zeros(len(kedges) + 1)
dig = np.digitize(kmag.flat, kedges) dig = np.digitize(kmag.flat, kedges)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size) Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges return dig, Nsum, xsum, W, k, kedges
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
""" """
Calculate the powerspectra given real space field Calculate the powerspectra given real space field
Args: Args:
field: real valued field field: real valued field
kmin: minimum k-value for binned powerspectra kmin: minimum k-value for binned powerspectra
dk: differential in each kbin dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?) boxsize: length of each boxlength (can be strangly shaped?)
Returns: Returns:
kbins: the central value of the bins for plotting kbins: the central value of the bins for plotting
power: real valued array of power in each bin power: real valued array of power in each bin
""" """
shape = field.shape shape = field.shape
nx, ny, nz = shape nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights) #initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform #fast fourier transform
fft_image = jnp.fft.fftn(field) fft_image = jnp.fft.fftn(field)
#absolute value of fast fourier transform #absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image)) pk = jnp.real(fft_image * jnp.conj(fft_image))
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
#calculating powerspectra Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
real = jnp.real(pk).reshape([-1]) length=xsum.size) * 1j
imag = jnp.imag(pk).reshape([-1]) Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') #normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#normalization for powerspectra #find central values of each bin
norm = np.prod(np.array(shape[:])).astype('float32')**2 kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
#find central values of each bin return kbins, P / norm
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False): def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
""" """
@ -131,18 +133,17 @@ def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=Fals
def gaussian_smoothing(im, sigma): def gaussian_smoothing(im, sigma):
""" """
im: 2d image im: 2d image
sigma: smoothing scale in px sigma: smoothing scale in px
""" """
# Compute k vector # Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]), kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])), jnp.fft.fftfreq(im.shape[1])),
axis=-1) axis=-1)
k = jnp.linalg.norm(kvec, axis=-1) k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k # We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma)) filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0,0] filter /= filter[0, 0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real

View file

@ -1,4 +1,4 @@
from setuptools import setup, find_packages from setuptools import find_packages, setup
setup( setup(
name='JaxPM', name='JaxPM',
@ -6,6 +6,6 @@ setup(
url='https://github.com/DifferentiableUniverseInitiative/JaxPM', url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
author='JaxPM developers', author='JaxPM developers',
description='A dead simple FastPM implementation in JAX', description='A dead simple FastPM implementation in JAX',
packages=find_packages(), packages=find_packages(),
install_requires=['jax', 'jax_cosmo'], install_requires=['jax', 'jax_cosmo'],
) )