Applying formatting

This commit is contained in:
EiffL 2024-07-09 14:54:34 -04:00
parent 835fa89aec
commit f28442bb48
14 changed files with 565 additions and 445 deletions

17
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)

View file

@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
## Objective
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
## Related Work
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- Borg

View file

@ -1,57 +1,80 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
import jax
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import jax.lax as lax
from jax.experimental.maps import xmap
from jax.experimental.maps import Mesh
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
from functools import partial
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x','y', ...],
axis_sizes={'x':cube_size, 'y':cube_size},
axis_resources={'x': 'nx', 'y':'ny',
'key_x':'nx', 'key_y':'ny'})
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y', ...],
axis_resources={'x': 'nx', 'y': 'ny'})
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y', ...],
axis_resources={'x': 'nx', 'y': 'ny'})
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
@ -68,6 +91,6 @@ with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
# jax.profiler.stop_trace()
print('Done')
print('Done')

View file

@ -1,48 +1,53 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import matplotlib.pylab as plt
import jax
import numpy as np
import jax.numpy as jnp
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfp = tfp.substrates.jax
tfd = tfp.distributions
def cic_paint(mesh, positions):
""" Paints positions onto mesh
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1, 8]), dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords.reshape([-1,8,3]).astype('int32'),
kernel.reshape([-1,8]),
dnums)
return mesh
# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos),
in_axis_resources=PartitionSpec('x', 'y', 'z'),
in_axis_resources=PartitionSpec('x', 'y', 'z'),
out_axis_resources=None)
devices = np.array(jax.devices()).reshape((2, 2, 1))
@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
m = jnp.zeros([32, 32, 32])
with mesh(devices, ('x', 'y', 'z')):
# Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Apply the sharded CiC function
res = f(m)
# Apply the sharded CiC function
res = f(m)
plt.imshow(res.sum(axis=2))
plt.show()
plt.show()

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
from functools import partial
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit, PartitionSpec
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
@ -14,35 +15,59 @@ mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'},
{0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1)
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},[...]),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
return a * factor
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def add(a, b):
return a + b
@partial(xmap,
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
@ -51,7 +76,7 @@ def fft3d(mesh):
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
Returns:
3D FFT of the input, note that the dimensions of the output
3D FFT of the input, note that the dimensions of the output
are tranposed.
"""
mesh = jnp.fft.fft(mesh)
@ -62,8 +87,8 @@ def fft3d(mesh):
@partial(xmap,
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
@ -72,10 +97,15 @@ def ifft3d(mesh):
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y',...],
out_axes={0: 'x', 2: 'y'},
in_axes=['x', 'y', ...],
out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources)
def fn(key):
""" Generate a distributed random normal distributions
@ -83,99 +113,126 @@ def normal(key, shape=[]):
key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample
"""
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
shape[1]//mesh_size['ny']]+shape[2:])
return jax.random.normal(
key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...],
[['x'], ['y'], [...]], [...], [...]),
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap,
in_axes=(['x', 'y', 'z'],
[['x'], ['y'], ['z']]),
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap,
in_axes=([...]),
out_axes={0: 'x', 2: 'y'},
axis_sizes={'x': mesh_size['nx'],
'y': mesh_size['ny']},
out_axes={
0: 'x',
2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources)
def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x,
y,
z), axis=-1)
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(pos):
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
mesh_shape[1]//mesh_size['ny']+2*halo_size]
+ mesh_shape[2:])
mesh = jnp.zeros([
mesh_shape[0] // mesh_size['nx'] +
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
mesh = paint.cic_paint(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
# Halo exchange along x
left = lax.pshuffle(mesh[-2*halo_size:],
left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:2*halo_size],
right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = mesh.at[:2*halo_size].add(left)
mesh = mesh.at[-2*halo_size:].add(right)
mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y
left = lax.pshuffle(mesh[:, -2*halo_size:],
left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :2*halo_size],
right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = mesh.at[:, :2*halo_size].add(left)
mesh = mesh.at[:, -2*halo_size:].add(right)
mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'},),
out_axes=({0: 'x', 2: 'y'}),
in_axes=(
{
0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(mesh, pos):
@ -198,11 +255,13 @@ def cic_read(mesh, pos, halo_size=0):
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = paint.cic_read(mesh, pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
res = paint.cic_read(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1])
return fn(mesh, pos)
@ -211,12 +270,14 @@ def cic_read(mesh, pos, halo_size=0):
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x):
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
Changes the logical shape of the array, but no shuffling of the
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:])
return x.reshape([
mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit,
@ -224,8 +285,8 @@ def reshape_dense_to_split(x):
out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x):
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
Changes the logical shape of the array, but no shuffling of the
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -1,13 +1,14 @@
import jax
from jax.lax import linear_solve_p
import jax.numpy as jnp
from jax.experimental.maps import xmap
from functools import partial
import jax_cosmo as jc
from jaxpm.kernels import fftk
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import growth_factor, growth_rate, dGfa
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
positions, halo_size) for f in forces_k]
forces = [
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces)
@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units
kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
kvec = [
k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
) / (box_size[0] * box_size[1] * box_size[2])
pk = pk * (mesh_shape[0] * mesh_shape[1] *
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
jnp.sqrt(jc.background.Esqr(cosmo, a)))
p = dops.scalar_multiply(
dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p

View file

@ -1,8 +1,8 @@
import jax.numpy as np
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint
from jax_cosmo.background import *
def E(cosmo, a):
r"""Scale factor dependent factor E(a) in the Hubble
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
"""
return (
3
* cosmo.wa
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
/ np.power(np.log(a - epsilon), 2)
)
return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
np.power(np.log(a - epsilon), 2))
def dEa(cosmo, a):
@ -89,15 +85,11 @@ def dEa(cosmo, a):
where :math:`f(a)` is the Dark Energy evolution parameter computed
by :py:meth:`.f_de`.
"""
return (
0.5
* (
-3 * cosmo.Omega_m * np.power(a, -4)
- 2 * cosmo.Omega_k * np.power(a, -3)
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
)
/ np.power(Esqr(cosmo, a), 0.5)
)
return (0.5 *
(-3 * cosmo.Omega_m * np.power(a, -4) -
2 * cosmo.Omega_k * np.power(a, -3) +
df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
np.power(Esqr(cosmo, a), 0.5))
def growth_factor(cosmo, a):
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth rate is not implemented for second order growth!"
)
"Gamma growth rate is not implemented for second order growth!")
return None
else:
return _growth_factor_second_ODE(cosmo, a)
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth factor is not implemented for second order growth!"
)
"Gamma growth factor is not implemented for second order growth!")
return None
else:
return _growth_rate_second_ODE(cosmo, a)
@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
atab = np.logspace(log10_amin, 0.0, steps)
def D_derivs(y, x):
q = (
2.0
- 0.5
* (
Omega_m_a(cosmo, x)
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
)
) / x
q = (2.0 - 0.5 *
(Omega_m_a(cosmo, x) +
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x
g1, g2 = y[0]
f1, f2 = y[1]
dy1da = [f1, -q * f1 + r * g1]
dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2]
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab)
# compute second order derivatives growth
@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
see :cite:`2019:Euclid Preparation VII, eqn.32`
"""
return Omega_m_a(cosmo, a) ** cosmo.gamma
return Omega_m_a(cosmo, a)**cosmo.gamma
def Gf(cosmo, a):
@ -503,7 +488,7 @@ def Gf(cosmo, a):
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a
D1f = f1 * g1 / a
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -532,7 +517,7 @@ def Gf2(cosmo, a):
"""
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a
D2f = f2 * g2 / a
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -563,13 +548,12 @@ def dGfa(cosmo, a):
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a
D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor']
f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) +
3 * a**2 * Ea * D1f)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
def dGf2a(cosmo, a):
@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
"""
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a
D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor']
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E * D2f)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)

View file

@ -1,25 +1,27 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size
""" Return k_vector given a shape (nc, nc, nc) and box_size
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
k.append(kd.astype(dtype))
del kd, kdshape
return k
k.append(kd.astype(dtype))
del kd, kdshape
return k
def gradient_kernel(kvec, direction, order=1):
"""
"""
Computes the gradient kernel in the requested direction
Parameters:
-----------
@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1):
wts: array
Complex kernel
"""
if order == 0:
wts = 1j * kvec[direction]
wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0
wts = wts.reshape(kvec[direction].shape)
return wts
else:
w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j
return wts
if order == 0:
wts = 1j * kvec[direction]
wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0
wts = wts.reshape(kvec[direction].shape)
return wts
else:
w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j
return wts
def laplace_kernel(kvec):
"""
"""
Compute the Laplace kernel from a given K vector
Parameters:
-----------
@ -56,16 +59,17 @@ def laplace_kernel(kvec):
wts: array
Complex kernel
"""
kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero()
kk[mask] = 1
wts = 1. / kk
imask = (~(kk == 0)).astype(int)
wts *= imask
return wts
kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero()
kk[mask] = 1
wts = 1. / kk
imask = (~(kk == 0)).astype(int)
wts *= imask
return wts
def longrange_kernel(kvec, r_split):
"""
"""
Computes a long range kernel
Parameters:
-----------
@ -78,29 +82,31 @@ def longrange_kernel(kvec, r_split):
wts: array
kernel
"""
if r_split != 0:
kk = sum(ki**2 for ki in kvec)
return np.exp(-kk * r_split**2)
else:
return 1.
if r_split != 0:
kk = sum(ki**2 for ki in kvec)
return np.exp(-kk * r_split**2)
else:
return 1.
def cic_compensation(kvec):
"""
"""
Computes cic compensation kernel.
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
Itself based on equation 18 (with p=2) of
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
Args:
kvec: array of k values in Fourier space
kvec: array of k values in Fourier space
Returns:
v: array of kernel
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
def PGD_kernel(kvec, kl, ks):
"""
"""
Computes the PGD kernel
Parameters:
-----------
@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks):
v: array
kernel
"""
kk = sum(ki**2 for ki in kvec)
kl2 = kl**2
ks4 = ks**4
mask = (kk == 0).nonzero()
kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int)
v *= imask
return v
kk = sum(ki**2 for ki in kvec)
kl2 = kl**2
ks4 = ks**4
mask = (kk == 0).nonzero()
kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int)
v *= imask
return v

View file

@ -1,11 +1,12 @@
import jax
import jax
import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d
from jaxpm.utils import gaussian_smoothing
def density_plane(positions,
box_shape,
@ -26,9 +27,11 @@ def density_plane(positions,
xy = xy / nx * plane_resolution
# Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
@ -36,45 +39,44 @@ def density_plane(positions,
# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane,
smoothing_sigma)
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
return density_plane
def convergence_Born(cosmo,
density_planes,
coords,
z_source):
"""
def convergence_Born(cosmo, density_planes, coords, z_source):
"""
Compute the Born convergence
Args:
cosmo: `Cosmology`, cosmology object.
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation.
Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values.
"""
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))
convergence = 0
for entry in density_planes:
r = entry['r']; a = entry['a']; p = entry['plane']
dx = entry['dx']; dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
convergence = 0
for entry in density_planes:
r = entry['r']
a = entry['a']
p = entry['plane']
dx = entry['dx']
dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates
im = map_coordinates(p,
coords * r / dx - 0.5,
order=1, mode="wrap")
# Interpolate at the density plane coordinates
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence
return convergence

View file

@ -1,6 +1,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
import haiku as hk
def _deBoorVectorized(x, t, c, p):
"""
@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
c: array of control points
p: degree of B-spline
"""
k = jnp.digitize(x, t) -1
d = [c[j + k - p] for j in range(0, p+1)]
for r in range(1, p+1):
for j in range(p, r-1, -1):
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
k = jnp.digitize(x, t) - 1
d = [c[j + k - p] for j in range(0, p + 1)]
for r in range(1, p + 1):
for j in range(p, r - 1, -1):
alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
return d[p]
class NeuralSplineFourierFilter(hk.Module):
"""A rotationally invariant filter parameterized by
"""A rotationally invariant filter parameterized by
a b-spline with parameters specified by a small NN."""
def __init__(self, n_knots=8, latent_size=16, name=None):
def __init__(self, n_knots=8, latent_size=16, name=None):
"""
n_knots: number of control points for the spline
"""
n_knots: number of control points for the spline
"""
super().__init__(name=name)
self.n_knots = n_knots
self.latent_size = latent_size
super().__init__(name=name)
self.n_knots = n_knots
self.latent_size = latent_size
def __call__(self, x, a):
"""
def __call__(self, x, a):
"""
x: array, scale, normalized to fftfreq default
a: scalar, scale factor
"""
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
net = jnp.sin(hk.Linear(self.latent_size)(net))
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
net = jnp.sin(hk.Linear(self.latent_size)(net))
w = hk.Linear(self.n_knots+1)(net)
k = hk.Linear(self.n_knots-1)(net)
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1,)),
jnp.cumsum(jax.nn.softmax(k))])
w = hk.Linear(self.n_knots + 1)(net)
k = hk.Linear(self.n_knots - 1)(net)
w = jnp.concatenate([jnp.zeros((1,)),
w])
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
# Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
w = jnp.concatenate([jnp.zeros((1, )), w])
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
# Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
3)

View file

@ -1,98 +1,100 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, weight=None):
""" Paints positions onto mesh
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,8]),
dnums)
return mesh
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
return mesh
def cic_read(mesh, positions):
""" Paints positions onto mesh
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
return (mesh[neighboor_coords[...,0],
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[...,jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0,
1))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,4]),
dnums)
return mesh
def compensate_cic(field):
"""
"""
Compensate for CiC painting
Args:
field: input 3D cic-painted field
Returns:
compensated_field
"""
nc = field.shape
kvec = fftk(nc)
nc = field.shape
kvec = fftk(nc)
delta_k = jnp.fft.rfftn(field)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)
delta_k = jnp.fft.rfftn(field)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel)
from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
"""
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
# Computes gravitational forces
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions)
for i in range(3)],axis=-1)
return jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
],
axis=-1)
def lpt(cosmo, initial_conditions, positions, a):
@ -34,25 +39,31 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta=initial_conditions)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
a) * initial_force
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
return field
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
@ -63,10 +74,10 @@ def make_ode_fn(mesh_shape):
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
return nbody_ode
@ -84,13 +95,16 @@ def pgd_correction(pos, params):
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range=PGD_kernel(kvec, kl, ks)
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
PGD_range = PGD_kernel(kvec, kl, ks)
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
for i in range(3)],axis=-1)
dpos_pgd = forces_pgd*alpha
return dpos_pgd
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
forces_pgd = jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
],
axis=-1)
dpos_pgd = forces_pgd * alpha
return dpos_pgd

View file

@ -1,99 +1,100 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
__all__ = ['power_spectrum']
def _initialize_pk(shape, boxsize, kmin, dk):
"""
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
"""
I = np.eye(len(shape), dtype='int') * -2 + 1
I = np.eye(len(shape), dtype='int') * -2 + 1
W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0
W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)
k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5
k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)
xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)
dig = np.digitize(kmag.flat, kedges)
dig = np.digitize(kmag.flat, kedges)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
"""
"""
Calculate the powerspectra given real space field
Args:
field: real valued field
field: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
Returns:
kbins: the central value of the bins for plotting
power: real valued array of power in each bin
"""
shape = field.shape
nx, ny, nz = shape
shape = field.shape
nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform
fft_image = jnp.fft.fftn(field)
#fast fourier transform
fft_image = jnp.fft.fftn(field)
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
return kbins, P / norm
def gaussian_smoothing(im, sigma):
"""
"""
im: 2d image
sigma: smoothing scale in px
sigma: smoothing scale in px
"""
# Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])),
axis=-1)
k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0,0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
# Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])),
axis=-1)
k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0, 0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real

View file

@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import find_packages, setup
setup(
name='JaxPM',
@ -6,6 +6,6 @@ setup(
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
author='JaxPM developers',
description='A dead simple FastPM implementation in JAX',
packages=find_packages(),
packages=find_packages(),
install_requires=['jax', 'jax_cosmo'],
)
)