Applying formatting

This commit is contained in:
EiffL 2024-07-09 14:54:34 -04:00
parent 835fa89aec
commit f28442bb48
14 changed files with 565 additions and 445 deletions

17
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)

View file

@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
## Objective ## Objective
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
## Related Work ## Related Work
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models. This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow - [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD - [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- Borg - Borg

View file

@ -1,57 +1,80 @@
# Can be executed with: # Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py # srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
import jax from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
import jax.lax as lax from jax.experimental.maps import Mesh, xmap
from jax.experimental.maps import xmap
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec, pjit from jax.experimental.pjit import PartitionSpec, pjit
from functools import partial
jax.distributed.initialize() jax.distributed.initialize()
cube_size = 2048 cube_size = 2048
@partial(xmap, @partial(xmap,
in_axes=[...], in_axes=[...],
out_axes=['x','y', ...], out_axes=['x', 'y', ...],
axis_sizes={'x':cube_size, 'y':cube_size}, axis_sizes={
axis_resources={'x': 'nx', 'y':'ny', 'x': cube_size,
'key_x':'nx', 'key_y':'ny'}) 'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key): def pnormal(key):
return jax.random.normal(key, shape=[cube_size]) return jax.random.normal(key, shape=[cube_size])
@partial(xmap, @partial(xmap,
in_axes={0:'x', 1:'y'}, in_axes={
out_axes=['x','y', ...], 0: 'x',
axis_resources={'x': 'nx', 'y': 'ny'}) 1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit @jax.jit
def pfft3d(mesh): def pfft3d(mesh):
# [x, y, z] # [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x] mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y] mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y] # [z, x, y]
return mesh return mesh
@partial(xmap, @partial(xmap,
in_axes={0:'x', 1:'y'}, in_axes={
out_axes=['x','y', ...], 0: 'x',
axis_resources={'x': 'nx', 'y': 'ny'}) 1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit @jax.jit
def pifft3d(mesh): def pifft3d(mesh):
# [z, x, y] # [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x] mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z] mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z] # [x, y, z]
return mesh return mesh
key = jax.random.PRNGKey(42) key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2)) # keys = jax.random.split(key, 4).reshape((2,2,2))
@ -68,6 +91,6 @@ with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key) # mesh = pnormal(key)
# kmesh = pfft3d(mesh) # kmesh = pfft3d(mesh)
# kmesh.block_until_ready() # kmesh.block_until_ready()
# jax.profiler.stop_trace() # jax.profiler.stop_trace()
print('Done') print('Done')

View file

@ -1,48 +1,53 @@
# Start this script with: # Start this script with:
# mpirun -np 4 python test_script.py # mpirun -np 4 python test_script.py
import os import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import matplotlib.pylab as plt import jax
import jax
import numpy as np
import jax.numpy as jnp
import jax.lax as lax import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit from jax.experimental.pjit import PartitionSpec, pjit
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfp = tfp.substrates.jax
tfd = tfp.distributions tfd = tfp.distributions
def cic_paint(mesh, positions): def cic_paint(mesh, positions):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1, 8]), dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords.reshape([-1,8,3]).astype('int32'),
kernel.reshape([-1,8]),
dnums)
return mesh
# And let's draw some points from some 3D distribution # And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.) dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0)) pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos), f = pjit(lambda x: cic_paint(x, pos),
in_axis_resources=PartitionSpec('x', 'y', 'z'), in_axis_resources=PartitionSpec('x', 'y', 'z'),
out_axis_resources=None) out_axis_resources=None)
devices = np.array(jax.devices()).reshape((2, 2, 1)) devices = np.array(jax.devices()).reshape((2, 2, 1))
@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
m = jnp.zeros([32, 32, 32]) m = jnp.zeros([32, 32, 32])
with mesh(devices, ('x', 'y', 'z')): with mesh(devices, ('x', 'y', 'z')):
# Shard the mesh, I'm not sure this is absolutely necessary # Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x, m = pjit(lambda x: x,
in_axis_resources=None, in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m) out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Apply the sharded CiC function # Apply the sharded CiC function
res = f(m) res = f(m)
plt.imshow(res.sum(axis=2)) plt.imshow(res.sum(axis=2))
plt.show() plt.show()

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
from functools import partial from functools import partial
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit, PartitionSpec
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line # TODO: add a way to configure axis resources from command line
@ -14,35 +15,59 @@ mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap, @partial(xmap,
in_axes=({0: 'x', 2: 'y'}, in_axes=({
{0: 'x', 2: 'y'}, 0: 'x',
{0: 'x', 2: 'y'}), 2: 'y'
out_axes=({0: 'x', 2: 'y'}), }, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources) axis_resources=axis_resources)
def stack3d(a, b, c): def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1) return jnp.stack([a, b, c], axis=-1)
@partial(xmap, @partial(xmap,
in_axes=({0: 'x', 2: 'y'},[...]), in_axes=({
out_axes=({0: 'x', 2: 'y'}), 0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources) axis_resources=axis_resources)
def scalar_multiply(a, factor): def scalar_multiply(a, factor):
return a * factor return a * factor
@partial(xmap, @partial(xmap,
in_axes=({0: 'x', 2: 'y'}, in_axes=({
{0: 'x', 2: 'y'}), 0: 'x',
out_axes=({0: 'x', 2: 'y'}), 2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources) axis_resources=axis_resources)
def add(a, b): def add(a, b):
return a + b return a + b
@partial(xmap, @partial(xmap,
in_axes=['x', 'y',...], in_axes=['x', 'y', ...],
out_axes=['x', 'y',...], out_axes=['x', 'y', ...],
axis_resources=axis_resources) axis_resources=axis_resources)
def fft3d(mesh): def fft3d(mesh):
""" Performs a 3D complex Fourier transform """ Performs a 3D complex Fourier transform
@ -51,7 +76,7 @@ def fft3d(mesh):
mesh: a real 3D tensor of shape [Nx, Ny, Nz] mesh: a real 3D tensor of shape [Nx, Ny, Nz]
Returns: Returns:
3D FFT of the input, note that the dimensions of the output 3D FFT of the input, note that the dimensions of the output
are tranposed. are tranposed.
""" """
mesh = jnp.fft.fft(mesh) mesh = jnp.fft.fft(mesh)
@ -62,8 +87,8 @@ def fft3d(mesh):
@partial(xmap, @partial(xmap,
in_axes=['x', 'y',...], in_axes=['x', 'y', ...],
out_axes=['x', 'y',...], out_axes=['x', 'y', ...],
axis_resources=axis_resources) axis_resources=axis_resources)
def ifft3d(mesh): def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh) mesh = jnp.fft.ifft(mesh)
@ -72,10 +97,15 @@ def ifft3d(mesh):
mesh = lax.all_to_all(mesh, 'x', 0, 0) mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real return jnp.fft.ifft(mesh).real
def normal(key, shape=[]): def normal(key, shape=[]):
@partial(xmap, @partial(xmap,
in_axes=['x', 'y',...], in_axes=['x', 'y', ...],
out_axes={0: 'x', 2: 'y'}, out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources) axis_resources=axis_resources)
def fn(key): def fn(key):
""" Generate a distributed random normal distributions """ Generate a distributed random normal distributions
@ -83,99 +113,126 @@ def normal(key, shape=[]):
key: array of random keys with same layout as computational mesh key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample shape: logical shape of array to sample
""" """
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'], return jax.random.normal(
shape[1]//mesh_size['ny']]+shape[2:]) key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key) return fn(key)
@partial(xmap, @partial(xmap,
in_axes=(['x', 'y', ...], in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
[['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...], out_axes=['x', 'y', ...],
axis_resources=axis_resources) axis_resources=axis_resources)
@jax.jit @jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk): def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2) kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk) return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap, @partial(xmap,
in_axes=(['x', 'y', 'z'], in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
[['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']), out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources) axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec): def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2) kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk) kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), return (kfield * kernel * 1j * 1 / 6.0 *
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
(8 * jnp.sin(kz) - jnp.sin(2 * kz)), 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))) 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap, @partial(xmap,
in_axes=([...]), in_axes=([...]),
out_axes={0: 'x', 2: 'y'}, out_axes={
axis_sizes={'x': mesh_size['nx'], 0: 'x',
'y': mesh_size['ny']}, 2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources) axis_resources=axis_resources)
def meshgrid(x, y, z): def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the """ Generates a mesh grid of appropriate size for the
computational mesh we have. computational mesh we have.
""" """
return jnp.stack(jnp.meshgrid(x, return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
y,
z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0): def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap, @partial(xmap,
in_axes=({0: 'x', 2: 'y'}), in_axes=({
out_axes=({0: 'x', 2: 'y'}), 0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources) axis_resources=axis_resources)
def fn(pos): def fn(pos):
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size, mesh = jnp.zeros([
mesh_shape[1]//mesh_size['ny']+2*halo_size] mesh_shape[0] // mesh_size['nx'] +
+ mesh_shape[2:]) 2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles # Paint particles
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) + mesh = paint.cic_paint(
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange # Perform halo exchange
# Halo exchange along x # Halo exchange along x
left = lax.pshuffle(mesh[-2*halo_size:], left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1], perm=range(mesh_size['nx'])[::-1],
axis_name='x') axis_name='x')
right = lax.pshuffle(mesh[:2*halo_size], right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1], perm=range(mesh_size['nx'])[::-1],
axis_name='x') axis_name='x')
mesh = mesh.at[:2*halo_size].add(left) mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2*halo_size:].add(right) mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y # Halo exchange along y
left = lax.pshuffle(mesh[:, -2*halo_size:], left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1], perm=range(mesh_size['ny'])[::-1],
axis_name='y') axis_name='y')
right = lax.pshuffle(mesh[:, :2*halo_size], right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1], perm=range(mesh_size['ny'])[::-1],
axis_name='y') axis_name='y')
mesh = mesh.at[:, :2*halo_size].add(left) mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2*halo_size:].add(right) mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh # removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size] return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos) return fn(pos)
def cic_read(mesh, pos, halo_size=0): def cic_read(mesh, pos, halo_size=0):
@partial(xmap, @partial(xmap,
in_axes=({0: 'x', 2: 'y'}, in_axes=(
{0: 'x', 2: 'y'},), {
out_axes=({0: 'x', 2: 'y'}), 0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources) axis_resources=axis_resources)
def fn(mesh, pos): def fn(mesh, pos):
@ -198,11 +255,13 @@ def cic_read(mesh, pos, halo_size=0):
mesh = jnp.concatenate([left, mesh, right], axis=1) mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions # Reading field at particles positions
res = paint.cic_read(mesh, pos.reshape(-1, 3) + res = paint.cic_read(
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1]) return res.reshape(pos.shape[:-1])
return fn(mesh, pos) return fn(mesh, pos)
@ -211,12 +270,14 @@ def cic_read(mesh, pos, halo_size=0):
out_axis_resources=PartitionSpec('nx', None, 'ny', None)) out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x): def reshape_dense_to_split(x):
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z] """ Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
Changes the logical shape of the array, but no shuffling of the Changes the logical shape of the array, but no shuffling of the
data should be necessary data should be necessary
""" """
shape = list(x.shape) shape = list(x.shape)
return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'], return x.reshape([
mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:]) mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit, @partial(pjit,
@ -224,8 +285,8 @@ def reshape_dense_to_split(x):
out_axis_resources=PartitionSpec('nx', 'ny')) out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x): def reshape_split_to_dense(x):
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z] """ Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
Changes the logical shape of the array, but no shuffling of the Changes the logical shape of the array, but no shuffling of the
data should be necessary data should be necessary
""" """
shape = list(x.shape) shape = list(x.shape)
return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:]) return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -1,13 +1,14 @@
import jax
from jax.lax import linear_solve_p
import jax.numpy as jnp
from jax.experimental.maps import xmap
from functools import partial from functools import partial
import jax_cosmo as jc
from jaxpm.kernels import fftk import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import growth_factor, growth_rate, dGfa from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16): def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
forces_k = dops.gradient_laplace_kernel(delta_k, kvec) forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions # Recovers forces at particle positions
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), forces = [
positions, halo_size) for f in forces_k] dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces) return dops.stack3d(*forces)
@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
field = dops.fft3d(dops.reshape_split_to_dense(field)) field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units # Rescaling k to physical units
kvec = [k.squeeze() / box_size[i] * mesh_shape[i] kvec = [
for i, k in enumerate(fftk(mesh_shape, symmetric=False))] k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256) k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k) pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2] pk = pk * (mesh_shape[0] * mesh_shape[1] *
) / (box_size[0] * box_size[1] * box_size[2]) mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk)) field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta_k=initial_conditions) initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a)) dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) * p = dops.scalar_multiply(
jnp.sqrt(jc.background.Esqr(cosmo, a))) dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p return dx, p

View file

@ -1,8 +1,8 @@
import jax.numpy as np import jax.numpy as np
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint from jax_cosmo.scipy.ode import odeint
from jax_cosmo.background import *
def E(cosmo, a): def E(cosmo, a):
r"""Scale factor dependent factor E(a) in the Hubble r"""Scale factor dependent factor E(a) in the Hubble
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)- \frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)} \frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
""" """
return ( return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
3 np.power(np.log(a - epsilon), 2))
* cosmo.wa
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
/ np.power(np.log(a - epsilon), 2)
)
def dEa(cosmo, a): def dEa(cosmo, a):
@ -89,15 +85,11 @@ def dEa(cosmo, a):
where :math:`f(a)` is the Dark Energy evolution parameter computed where :math:`f(a)` is the Dark Energy evolution parameter computed
by :py:meth:`.f_de`. by :py:meth:`.f_de`.
""" """
return ( return (0.5 *
0.5 (-3 * cosmo.Omega_m * np.power(a, -4) -
* ( 2 * cosmo.Omega_k * np.power(a, -3) +
-3 * cosmo.Omega_m * np.power(a, -4) df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
- 2 * cosmo.Omega_k * np.power(a, -3) np.power(Esqr(cosmo, a), 0.5))
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
)
/ np.power(Esqr(cosmo, a), 0.5)
)
def growth_factor(cosmo, a): def growth_factor(cosmo, a):
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
""" """
if cosmo._flags["gamma_growth"]: if cosmo._flags["gamma_growth"]:
raise NotImplementedError( raise NotImplementedError(
"Gamma growth rate is not implemented for second order growth!" "Gamma growth rate is not implemented for second order growth!")
)
return None return None
else: else:
return _growth_factor_second_ODE(cosmo, a) return _growth_factor_second_ODE(cosmo, a)
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
""" """
if cosmo._flags["gamma_growth"]: if cosmo._flags["gamma_growth"]:
raise NotImplementedError( raise NotImplementedError(
"Gamma growth factor is not implemented for second order growth!" "Gamma growth factor is not implemented for second order growth!")
)
return None return None
else: else:
return _growth_rate_second_ODE(cosmo, a) return _growth_rate_second_ODE(cosmo, a)
@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
atab = np.logspace(log10_amin, 0.0, steps) atab = np.logspace(log10_amin, 0.0, steps)
def D_derivs(y, x): def D_derivs(y, x):
q = ( q = (2.0 - 0.5 *
2.0 (Omega_m_a(cosmo, x) +
- 0.5 (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
* (
Omega_m_a(cosmo, x)
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
)
) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x r = 1.5 * Omega_m_a(cosmo, x) / x / x
g1, g2 = y[0] g1, g2 = y[0]
f1, f2 = y[1] f1, f2 = y[1]
dy1da = [f1, -q * f1 + r * g1] dy1da = [f1, -q * f1 + r * g1]
dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2] dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]]) return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]]) y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab) y = odeint(D_derivs, y0, atab)
# compute second order derivatives growth # compute second order derivatives growth
@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
see :cite:`2019:Euclid Preparation VII, eqn.32` see :cite:`2019:Euclid Preparation VII, eqn.32`
""" """
return Omega_m_a(cosmo, a) ** cosmo.gamma return Omega_m_a(cosmo, a)**cosmo.gamma
def Gf(cosmo, a): def Gf(cosmo, a):
@ -503,7 +488,7 @@ def Gf(cosmo, a):
""" """
f1 = growth_rate(cosmo, a) f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a) g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a D1f = f1 * g1 / a
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -532,7 +517,7 @@ def Gf2(cosmo, a):
""" """
f2 = growth_rate_second(cosmo, a) f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a) g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a D2f = f2 * g2 / a
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -563,13 +548,12 @@ def dGfa(cosmo, a):
""" """
f1 = growth_rate(cosmo, a) f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a) g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor'] cache = cosmo._workspace['background.growth_factor']
f1p = cache['h'] / cache['a'] * cache['g'] f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p) f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a) Ea = E(cosmo, a)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
3 * a**2 * Ea * D1f)
def dGf2a(cosmo, a): def dGf2a(cosmo, a):
@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
""" """
f2 = growth_rate_second(cosmo, a) f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a) g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor'] cache = cosmo._workspace['background.growth_factor']
f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p) f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a) E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
3 * a**2 * E * D2f)

View file

@ -1,25 +1,27 @@
import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32): def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size """ Return k_vector given a shape (nc, nc, nc) and box_size
""" """
k = [] k = []
for d in range(len(shape)): for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d]) kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int') kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1: if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1] kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd) kdshape[d] = len(kd)
kd = kd.reshape(kdshape) kd = kd.reshape(kdshape)
k.append(kd.astype(dtype))
del kd, kdshape
return k
k.append(kd.astype(dtype))
del kd, kdshape
return k
def gradient_kernel(kvec, direction, order=1): def gradient_kernel(kvec, direction, order=1):
""" """
Computes the gradient kernel in the requested direction Computes the gradient kernel in the requested direction
Parameters: Parameters:
----------- -----------
@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1):
wts: array wts: array
Complex kernel Complex kernel
""" """
if order == 0: if order == 0:
wts = 1j * kvec[direction] wts = 1j * kvec[direction]
wts = jnp.squeeze(wts) wts = jnp.squeeze(wts)
wts[len(wts) // 2] = 0 wts[len(wts) // 2] = 0
wts = wts.reshape(kvec[direction].shape) wts = wts.reshape(kvec[direction].shape)
return wts return wts
else: else:
w = kvec[direction] w = kvec[direction]
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
wts = a * 1j wts = a * 1j
return wts return wts
def laplace_kernel(kvec): def laplace_kernel(kvec):
""" """
Compute the Laplace kernel from a given K vector Compute the Laplace kernel from a given K vector
Parameters: Parameters:
----------- -----------
@ -56,16 +59,17 @@ def laplace_kernel(kvec):
wts: array wts: array
Complex kernel Complex kernel
""" """
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero() mask = (kk == 0).nonzero()
kk[mask] = 1 kk[mask] = 1
wts = 1. / kk wts = 1. / kk
imask = (~(kk == 0)).astype(int) imask = (~(kk == 0)).astype(int)
wts *= imask wts *= imask
return wts return wts
def longrange_kernel(kvec, r_split): def longrange_kernel(kvec, r_split):
""" """
Computes a long range kernel Computes a long range kernel
Parameters: Parameters:
----------- -----------
@ -78,29 +82,31 @@ def longrange_kernel(kvec, r_split):
wts: array wts: array
kernel kernel
""" """
if r_split != 0: if r_split != 0:
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
return np.exp(-kk * r_split**2) return np.exp(-kk * r_split**2)
else: else:
return 1. return 1.
def cic_compensation(kvec): def cic_compensation(kvec):
""" """
Computes cic compensation kernel. Computes cic compensation kernel.
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
Itself based on equation 18 (with p=2) of Itself based on equation 18 (with p=2) of
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_ `Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
Args: Args:
kvec: array of k values in Fourier space kvec: array of k values in Fourier space
Returns: Returns:
v: array of kernel v: array of kernel
""" """
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2) wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts return wts
def PGD_kernel(kvec, kl, ks): def PGD_kernel(kvec, kl, ks):
""" """
Computes the PGD kernel Computes the PGD kernel
Parameters: Parameters:
----------- -----------
@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks):
v: array v: array
kernel kernel
""" """
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
kl2 = kl**2 kl2 = kl**2
ks4 = ks**4 ks4 = ks**4
mask = (kk == 0).nonzero() mask = (kk == 0).nonzero()
kk[mask] = 1 kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int) imask = (~(kk == 0)).astype(int)
v *= imask v *= imask
return v return v

View file

@ -1,11 +1,12 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo import jax_cosmo
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d from jaxpm.painting import cic_paint_2d
from jaxpm.utils import gaussian_smoothing
def density_plane(positions, def density_plane(positions,
box_shape, box_shape,
@ -26,9 +27,11 @@ def density_plane(positions,
xy = xy / nx * plane_resolution xy = xy / nx * plane_resolution
# Selecting only particles that fall inside the volume of interest # Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.) weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane # Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight) density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
# Apply density normalization # Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) * density_plane = density_plane / ((nx / plane_resolution) *
@ -36,45 +39,44 @@ def density_plane(positions,
# Apply Gaussian smoothing if requested # Apply Gaussian smoothing if requested
if smoothing_sigma is not None: if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane, density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
smoothing_sigma)
return density_plane return density_plane
def convergence_Born(cosmo, def convergence_Born(cosmo, density_planes, coords, z_source):
density_planes, """
coords,
z_source):
"""
Compute the Born convergence Compute the Born convergence
Args: Args:
cosmo: `Cosmology`, cosmology object. cosmo: `Cosmology`, cosmology object.
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2]. coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] . z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation. name: `string`, name of the operation.
Returns: Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values. `Tensor` of shape [batch_size, N, Nz], of convergence values.
""" """
# Compute constant prefactor: # Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2 constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies # Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source)) r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))
convergence = 0 convergence = 0
for entry in density_planes: for entry in density_planes:
r = entry['r']; a = entry['a']; p = entry['plane'] r = entry['r']
dx = entry['dx']; dz = entry['dz'] a = entry['a']
# Normalize density planes p = entry['plane']
density_normalization = dz * r / a dx = entry['dx']
p = (p - p.mean()) * constant_factor * density_normalization dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates # Interpolate at the density plane coordinates
im = map_coordinates(p, im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
coords * r / dx - 0.5,
order=1, mode="wrap")
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1]) convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence return convergence

View file

@ -1,6 +1,7 @@
import haiku as hk
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import haiku as hk
def _deBoorVectorized(x, t, c, p): def _deBoorVectorized(x, t, c, p):
""" """
@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
c: array of control points c: array of control points
p: degree of B-spline p: degree of B-spline
""" """
k = jnp.digitize(x, t) -1 k = jnp.digitize(x, t) - 1
d = [c[j + k - p] for j in range(0, p+1)] d = [c[j + k - p] for j in range(0, p + 1)]
for r in range(1, p+1): for r in range(1, p + 1):
for j in range(p, r-1, -1): for j in range(p, r - 1, -1):
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p]) alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j] d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
return d[p] return d[p]
class NeuralSplineFourierFilter(hk.Module): class NeuralSplineFourierFilter(hk.Module):
"""A rotationally invariant filter parameterized by """A rotationally invariant filter parameterized by
a b-spline with parameters specified by a small NN.""" a b-spline with parameters specified by a small NN."""
def __init__(self, n_knots=8, latent_size=16, name=None): def __init__(self, n_knots=8, latent_size=16, name=None):
"""
n_knots: number of control points for the spline
""" """
n_knots: number of control points for the spline super().__init__(name=name)
""" self.n_knots = n_knots
super().__init__(name=name) self.latent_size = latent_size
self.n_knots = n_knots
self.latent_size = latent_size
def __call__(self, x, a): def __call__(self, x, a):
""" """
x: array, scale, normalized to fftfreq default x: array, scale, normalized to fftfreq default
a: scalar, scale factor a: scalar, scale factor
""" """
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a))) net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
net = jnp.sin(hk.Linear(self.latent_size)(net)) net = jnp.sin(hk.Linear(self.latent_size)(net))
w = hk.Linear(self.n_knots+1)(net) w = hk.Linear(self.n_knots + 1)(net)
k = hk.Linear(self.n_knots-1)(net) k = hk.Linear(self.n_knots - 1)(net)
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1,)),
jnp.cumsum(jax.nn.softmax(k))])
w = jnp.concatenate([jnp.zeros((1,)), # make sure the knots sum to 1 and are in the interval 0,1
w]) k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
# Augment with repeating points w = jnp.concatenate([jnp.zeros((1, )), w])
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) # Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
3)

View file

@ -1,98 +1,100 @@
import jax import jax
import jax.numpy as jnp
import jax.lax as lax import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, weight=None): def cic_paint(mesh, positions, weight=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None: if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers( neighboor_coords = jnp.mod(
update_window_dims=(), neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
inserted_window_dims=(0, 1, 2), jnp.array(mesh.shape))
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh, dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
neighboor_coords, inserted_window_dims=(0, 1, 2),
kernel.reshape([-1,8]), scatter_dims_to_operand_dims=(0, 1,
dnums) 2))
return mesh mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
return mesh
def cic_read(mesh, positions): def cic_read(mesh, positions):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape)) neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
return (mesh[neighboor_coords[...,0],
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh """ Paints positions onto a 2d mesh
mesh: [nx, ny] mesh: [nx, ny]
positions: [npart, 2] positions: [npart, 2]
weight: [npart] weight: [npart]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None: if weight is not None:
kernel = kernel * weight[...,jnp.newaxis] kernel = kernel * weight[..., jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0,
1))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,4]),
dnums)
return mesh
def compensate_cic(field): def compensate_cic(field):
""" """
Compensate for CiC painting Compensate for CiC painting
Args: Args:
field: input 3D cic-painted field field: input 3D cic-painted field
Returns: Returns:
compensated_field compensated_field
""" """
nc = field.shape nc = field.shape
kvec = fftk(nc) kvec = fftk(nc)
delta_k = jnp.fft.rfftn(field) delta_k = jnp.fft.rfftn(field)
delta_k = cic_compensation(kvec) * delta_k delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k) return jnp.fft.irfftn(delta_k)

View file

@ -1,11 +1,12 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel)
from jaxpm.painting import cic_paint, cic_read from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
""" """
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
delta_k = jnp.fft.rfftn(delta) delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential # Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions) return jnp.stack([
for i in range(3)],axis=-1) cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
],
axis=-1)
def lpt(cosmo, initial_conditions, positions, a): def lpt(cosmo, initial_conditions, positions, a):
@ -34,25 +39,31 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta=initial_conditions) initial_force = pm_forces(positions, delta=initial_conditions)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
a) * initial_force
return dx, p, f return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed): def linear_field(mesh_shape, box_size, pk, seed):
""" """
Generate initial conditions. Generate initial conditions.
""" """
kvec = fftk(mesh_shape) kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape) field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5 field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field) field = jnp.fft.irfftn(field)
return field return field
def make_ode_fn(mesh_shape): def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo): def nbody_ode(state, a, cosmo):
""" """
state is a tuple (position, velocities) state is a tuple (position, velocities)
@ -63,10 +74,10 @@ def make_ode_fn(mesh_shape):
# Computes the update of position (drift) # Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel return dpos, dvel
return nbody_ode return nbody_ode
@ -84,13 +95,16 @@ def pgd_correction(pos, params):
delta = cic_paint(jnp.zeros(mesh_shape), pos) delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta) delta_k = jnp.fft.rfftn(delta)
PGD_range=PGD_kernel(kvec, kl, ks) PGD_range = PGD_kernel(kvec, kl, ks)
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos) pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
for i in range(3)],axis=-1)
forces_pgd = jnp.stack([
dpos_pgd = forces_pgd*alpha cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
return dpos_pgd ],
axis=-1)
dpos_pgd = forces_pgd * alpha
return dpos_pgd

View file

@ -1,99 +1,100 @@
import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm from jax.scipy.stats import norm
__all__ = ['power_spectrum'] __all__ = ['power_spectrum']
def _initialize_pk(shape, boxsize, kmin, dk): def _initialize_pk(shape, boxsize, kmin, dk):
""" """
Helper function to initialize various (fixed) values for powerspectra... not differentiable! Helper function to initialize various (fixed) values for powerspectra... not differentiable!
""" """
I = np.eye(len(shape), dtype='int') * -2 + 1 I = np.eye(len(shape), dtype='int') * -2 + 1
W = np.empty(shape, dtype='f4') W = np.empty(shape, dtype='f4')
W[...] = 2.0 W[...] = 2.0
W[..., 0] = 1.0 W[..., 0] = 1.0
W[..., -1] = 1.0 W[..., -1] = 1.0
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk) kedges = np.arange(kmin, kmax, dk)
k = [ k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape) for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
] ]
kmag = sum(ki**2 for ki in k)**0.5 kmag = sum(ki**2 for ki in k)**0.5
xsum = np.zeros(len(kedges) + 1) xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1) Nsum = np.zeros(len(kedges) + 1)
dig = np.digitize(kmag.flat, kedges) dig = np.digitize(kmag.flat, kedges)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size) Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges return dig, Nsum, xsum, W, k, kedges
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
""" """
Calculate the powerspectra given real space field Calculate the powerspectra given real space field
Args: Args:
field: real valued field field: real valued field
kmin: minimum k-value for binned powerspectra kmin: minimum k-value for binned powerspectra
dk: differential in each kbin dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?) boxsize: length of each boxlength (can be strangly shaped?)
Returns: Returns:
kbins: the central value of the bins for plotting kbins: the central value of the bins for plotting
power: real valued array of power in each bin power: real valued array of power in each bin
""" """
shape = field.shape shape = field.shape
nx, ny, nz = shape nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights) #initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform #fast fourier transform
fft_image = jnp.fft.fftn(field) fft_image = jnp.fft.fftn(field)
#absolute value of fast fourier transform #absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image)) pk = jnp.real(fft_image * jnp.conj(fft_image))
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
#calculating powerspectra Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
real = jnp.real(pk).reshape([-1]) length=xsum.size) * 1j
imag = jnp.imag(pk).reshape([-1]) Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') #normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#normalization for powerspectra #find central values of each bin
norm = np.prod(np.array(shape[:])).astype('float32')**2 kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
#find central values of each bin return kbins, P / norm
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
def gaussian_smoothing(im, sigma): def gaussian_smoothing(im, sigma):
""" """
im: 2d image im: 2d image
sigma: smoothing scale in px sigma: smoothing scale in px
""" """
# Compute k vector # Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]), kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])), jnp.fft.fftfreq(im.shape[1])),
axis=-1) axis=-1)
k = jnp.linalg.norm(kvec, axis=-1) k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k # We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma)) filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0,0] filter /= filter[0, 0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real

View file

@ -1,4 +1,4 @@
from setuptools import setup, find_packages from setuptools import find_packages, setup
setup( setup(
name='JaxPM', name='JaxPM',
@ -6,6 +6,6 @@ setup(
url='https://github.com/DifferentiableUniverseInitiative/JaxPM', url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
author='JaxPM developers', author='JaxPM developers',
description='A dead simple FastPM implementation in JAX', description='A dead simple FastPM implementation in JAX',
packages=find_packages(), packages=find_packages(),
install_requires=['jax', 'jax_cosmo'], install_requires=['jax', 'jax_cosmo'],
) )