forked from guilhem_lavaux/JaxPM
Merge branch 'main' into neural_ode
This commit is contained in:
commit
9a279d2d6c
16 changed files with 858 additions and 328 deletions
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v2.3.0
|
||||||
|
hooks:
|
||||||
|
- id: check-yaml
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- repo: https://github.com/google/yapf
|
||||||
|
rev: v0.40.2
|
||||||
|
hooks:
|
||||||
|
- id: yapf
|
||||||
|
args: ['--parallel', '--in-place']
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.13.2
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
name: isort (python)
|
|
@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
|
||||||
|
|
||||||
## Objective
|
## Objective
|
||||||
|
|
||||||
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
||||||
|
|
||||||
## Related Work
|
## Related Work
|
||||||
|
|
||||||
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
|
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
|
||||||
|
|
||||||
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
|
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
|
||||||
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
||||||
- Borg
|
- Borg
|
||||||
|
|
||||||
|
|
||||||
|
|
14
dev/job_pfft.sh
Normal file
14
dev/job_pfft.sh
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
#!/bin/bash
|
||||||
|
#SBATCH -A m1727
|
||||||
|
#SBATCH -C gpu
|
||||||
|
#SBATCH -q debug
|
||||||
|
#SBATCH -t 0:05:00
|
||||||
|
#SBATCH -N 2
|
||||||
|
#SBATCH --ntasks-per-node=4
|
||||||
|
#SBATCH -c 32
|
||||||
|
#SBATCH --gpus-per-task=1
|
||||||
|
#SBATCH --gpu-bind=none
|
||||||
|
|
||||||
|
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
|
||||||
|
export SLURM_CPU_BIND="cores"
|
||||||
|
srun python test_pfft.py
|
96
dev/test_pfft.py
Normal file
96
dev/test_pfft.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
# Can be executed with:
|
||||||
|
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.lax as lax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.maps import Mesh, xmap
|
||||||
|
from jax.experimental.pjit import PartitionSpec, pjit
|
||||||
|
|
||||||
|
jax.distributed.initialize()
|
||||||
|
|
||||||
|
cube_size = 2048
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=[...],
|
||||||
|
out_axes=['x', 'y', ...],
|
||||||
|
axis_sizes={
|
||||||
|
'x': cube_size,
|
||||||
|
'y': cube_size
|
||||||
|
},
|
||||||
|
axis_resources={
|
||||||
|
'x': 'nx',
|
||||||
|
'y': 'ny',
|
||||||
|
'key_x': 'nx',
|
||||||
|
'key_y': 'ny'
|
||||||
|
})
|
||||||
|
def pnormal(key):
|
||||||
|
return jax.random.normal(key, shape=[cube_size])
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes={
|
||||||
|
0: 'x',
|
||||||
|
1: 'y'
|
||||||
|
},
|
||||||
|
out_axes=['x', 'y', ...],
|
||||||
|
axis_resources={
|
||||||
|
'x': 'nx',
|
||||||
|
'y': 'ny'
|
||||||
|
})
|
||||||
|
@jax.jit
|
||||||
|
def pfft3d(mesh):
|
||||||
|
# [x, y, z]
|
||||||
|
mesh = jnp.fft.fft(mesh) # Transform on z
|
||||||
|
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
||||||
|
mesh = jnp.fft.fft(mesh) # Transform on x
|
||||||
|
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
||||||
|
mesh = jnp.fft.fft(mesh) # Transform on y
|
||||||
|
# [z, x, y]
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes={
|
||||||
|
0: 'x',
|
||||||
|
1: 'y'
|
||||||
|
},
|
||||||
|
out_axes=['x', 'y', ...],
|
||||||
|
axis_resources={
|
||||||
|
'x': 'nx',
|
||||||
|
'y': 'ny'
|
||||||
|
})
|
||||||
|
@jax.jit
|
||||||
|
def pifft3d(mesh):
|
||||||
|
# [z, x, y]
|
||||||
|
mesh = jnp.fft.ifft(mesh) # Transform on y
|
||||||
|
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
||||||
|
mesh = jnp.fft.ifft(mesh) # Transform on x
|
||||||
|
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
||||||
|
mesh = jnp.fft.ifft(mesh) # Transform on z
|
||||||
|
# [x, y, z]
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
key = jax.random.PRNGKey(42)
|
||||||
|
# keys = jax.random.split(key, 4).reshape((2,2,2))
|
||||||
|
|
||||||
|
# We reshape all our devices to the mesh shape we want
|
||||||
|
devices = np.array(jax.devices()).reshape((2, 4))
|
||||||
|
|
||||||
|
with Mesh(devices, ('nx', 'ny')):
|
||||||
|
mesh = pnormal(key)
|
||||||
|
kmesh = pfft3d(mesh)
|
||||||
|
kmesh.block_until_ready()
|
||||||
|
|
||||||
|
# jax.profiler.start_trace("tensorboard")
|
||||||
|
# with Mesh(devices, ('nx', 'ny')):
|
||||||
|
# mesh = pnormal(key)
|
||||||
|
# kmesh = pfft3d(mesh)
|
||||||
|
# kmesh.block_until_ready()
|
||||||
|
# jax.profiler.stop_trace()
|
||||||
|
|
||||||
|
print('Done')
|
|
@ -1,48 +1,53 @@
|
||||||
# Start this script with:
|
# Start this script with:
|
||||||
# mpirun -np 4 python test_script.py
|
# mpirun -np 4 python test_script.py
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
|
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
|
||||||
import matplotlib.pylab as plt
|
import jax
|
||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import matplotlib.pylab as plt
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow_probability as tfp
|
||||||
from jax.experimental.maps import mesh, xmap
|
from jax.experimental.maps import mesh, xmap
|
||||||
from jax.experimental.pjit import PartitionSpec, pjit
|
from jax.experimental.pjit import PartitionSpec, pjit
|
||||||
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
|
|
||||||
|
tfp = tfp.substrates.jax
|
||||||
tfd = tfp.distributions
|
tfd = tfp.distributions
|
||||||
|
|
||||||
|
|
||||||
def cic_paint(mesh, positions):
|
def cic_paint(mesh, positions):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1, 2),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1,
|
||||||
|
2))
|
||||||
|
mesh = lax.scatter_add(
|
||||||
|
mesh,
|
||||||
|
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||||
|
kernel.reshape([-1, 8]), dnums)
|
||||||
|
return mesh
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(
|
|
||||||
update_window_dims=(),
|
|
||||||
inserted_window_dims=(0, 1, 2),
|
|
||||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
|
||||||
mesh = lax.scatter_add(mesh,
|
|
||||||
neighboor_coords.reshape([-1,8,3]).astype('int32'),
|
|
||||||
kernel.reshape([-1,8]),
|
|
||||||
dnums)
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
# And let's draw some points from some 3D distribution
|
# And let's draw some points from some 3D distribution
|
||||||
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
|
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
|
||||||
|
scale_identity_multiplier=3.)
|
||||||
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
|
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
|
||||||
|
|
||||||
f = pjit(lambda x: cic_paint(x, pos),
|
f = pjit(lambda x: cic_paint(x, pos),
|
||||||
in_axis_resources=PartitionSpec('x', 'y', 'z'),
|
in_axis_resources=PartitionSpec('x', 'y', 'z'),
|
||||||
out_axis_resources=None)
|
out_axis_resources=None)
|
||||||
|
|
||||||
devices = np.array(jax.devices()).reshape((2, 2, 1))
|
devices = np.array(jax.devices()).reshape((2, 2, 1))
|
||||||
|
@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
|
||||||
m = jnp.zeros([32, 32, 32])
|
m = jnp.zeros([32, 32, 32])
|
||||||
|
|
||||||
with mesh(devices, ('x', 'y', 'z')):
|
with mesh(devices, ('x', 'y', 'z')):
|
||||||
# Shard the mesh, I'm not sure this is absolutely necessary
|
# Shard the mesh, I'm not sure this is absolutely necessary
|
||||||
m = pjit(lambda x: x,
|
m = pjit(lambda x: x,
|
||||||
in_axis_resources=None,
|
in_axis_resources=None,
|
||||||
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
||||||
|
|
||||||
# Apply the sharded CiC function
|
# Apply the sharded CiC function
|
||||||
res = f(m)
|
res = f(m)
|
||||||
|
|
||||||
plt.imshow(res.sum(axis=2))
|
plt.imshow(res.sum(axis=2))
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
0
jaxpm/experimental/__init__.py
Normal file
0
jaxpm/experimental/__init__.py
Normal file
292
jaxpm/experimental/distributed_ops.py
Normal file
292
jaxpm/experimental/distributed_ops.py
Normal file
|
@ -0,0 +1,292 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.lax as lax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax_cosmo as jc
|
||||||
|
from jax.experimental.maps import xmap
|
||||||
|
from jax.experimental.pjit import PartitionSpec, pjit
|
||||||
|
|
||||||
|
import jaxpm.painting as paint
|
||||||
|
|
||||||
|
# TODO: add a way to configure axis resources from command line
|
||||||
|
axis_resources = {'x': 'nx', 'y': 'ny'}
|
||||||
|
mesh_size = {'nx': 2, 'ny': 2}
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}, {
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}, {
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
out_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def stack3d(a, b, c):
|
||||||
|
return jnp.stack([a, b, c], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}, [...]),
|
||||||
|
out_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def scalar_multiply(a, factor):
|
||||||
|
return a * factor
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}, {
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
out_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def add(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=['x', 'y', ...],
|
||||||
|
out_axes=['x', 'y', ...],
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def fft3d(mesh):
|
||||||
|
""" Performs a 3D complex Fourier transform
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
3D FFT of the input, note that the dimensions of the output
|
||||||
|
are tranposed.
|
||||||
|
"""
|
||||||
|
mesh = jnp.fft.fft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
||||||
|
mesh = jnp.fft.fft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'y', 0, 0)
|
||||||
|
return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y]
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=['x', 'y', ...],
|
||||||
|
out_axes=['x', 'y', ...],
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def ifft3d(mesh):
|
||||||
|
mesh = jnp.fft.ifft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'y', 0, 0)
|
||||||
|
mesh = jnp.fft.ifft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
||||||
|
return jnp.fft.ifft(mesh).real
|
||||||
|
|
||||||
|
|
||||||
|
def normal(key, shape=[]):
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=['x', 'y', ...],
|
||||||
|
out_axes={
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
},
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def fn(key):
|
||||||
|
""" Generate a distributed random normal distributions
|
||||||
|
Args:
|
||||||
|
key: array of random keys with same layout as computational mesh
|
||||||
|
shape: logical shape of array to sample
|
||||||
|
"""
|
||||||
|
return jax.random.normal(
|
||||||
|
key,
|
||||||
|
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
|
||||||
|
shape[2:])
|
||||||
|
|
||||||
|
return fn(key)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
|
||||||
|
out_axes=['x', 'y', ...],
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
@jax.jit
|
||||||
|
def scale_by_power_spectrum(kfield, kvec, k, pk):
|
||||||
|
kx, ky, kz = kvec
|
||||||
|
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
|
||||||
|
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
|
||||||
|
out_axes=(['x', 'y', 'z']),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def gradient_laplace_kernel(kfield, kvec):
|
||||||
|
kx, ky, kz = kvec
|
||||||
|
kk = (kx**2 + ky**2 + kz**2)
|
||||||
|
kernel = jnp.where(kk == 0, 1., 1. / kk)
|
||||||
|
return (kfield * kernel * 1j * 1 / 6.0 *
|
||||||
|
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
|
||||||
|
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
|
||||||
|
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=([...]),
|
||||||
|
out_axes={
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
},
|
||||||
|
axis_sizes={
|
||||||
|
'x': mesh_size['nx'],
|
||||||
|
'y': mesh_size['ny']
|
||||||
|
},
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def meshgrid(x, y, z):
|
||||||
|
""" Generates a mesh grid of appropriate size for the
|
||||||
|
computational mesh we have.
|
||||||
|
"""
|
||||||
|
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def cic_paint(pos, mesh_shape, halo_size=0):
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
out_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def fn(pos):
|
||||||
|
|
||||||
|
mesh = jnp.zeros([
|
||||||
|
mesh_shape[0] // mesh_size['nx'] +
|
||||||
|
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
|
||||||
|
] + mesh_shape[2:])
|
||||||
|
|
||||||
|
# Paint particles
|
||||||
|
mesh = paint.cic_paint(
|
||||||
|
mesh,
|
||||||
|
pos.reshape(-1, 3) +
|
||||||
|
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||||
|
|
||||||
|
# Perform halo exchange
|
||||||
|
# Halo exchange along x
|
||||||
|
left = lax.pshuffle(mesh[-2 * halo_size:],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
right = lax.pshuffle(mesh[:2 * halo_size],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
mesh = mesh.at[:2 * halo_size].add(left)
|
||||||
|
mesh = mesh.at[-2 * halo_size:].add(right)
|
||||||
|
|
||||||
|
# Halo exchange along y
|
||||||
|
left = lax.pshuffle(mesh[:, -2 * halo_size:],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
right = lax.pshuffle(mesh[:, :2 * halo_size],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
mesh = mesh.at[:, :2 * halo_size].add(left)
|
||||||
|
mesh = mesh.at[:, -2 * halo_size:].add(right)
|
||||||
|
|
||||||
|
# removing halo and returning mesh
|
||||||
|
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
||||||
|
return fn(pos)
|
||||||
|
|
||||||
|
|
||||||
|
def cic_read(mesh, pos, halo_size=0):
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=(
|
||||||
|
{
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
},
|
||||||
|
),
|
||||||
|
out_axes=({
|
||||||
|
0: 'x',
|
||||||
|
2: 'y'
|
||||||
|
}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def fn(mesh, pos):
|
||||||
|
|
||||||
|
# Halo exchange to grab neighboring borders
|
||||||
|
# Exchange along x
|
||||||
|
left = lax.pshuffle(mesh[-halo_size:],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
right = lax.pshuffle(mesh[:halo_size],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
mesh = jnp.concatenate([left, mesh, right], axis=0)
|
||||||
|
# Exchange along y
|
||||||
|
left = lax.pshuffle(mesh[:, -halo_size:],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
right = lax.pshuffle(mesh[:, :halo_size],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
mesh = jnp.concatenate([left, mesh, right], axis=1)
|
||||||
|
|
||||||
|
# Reading field at particles positions
|
||||||
|
res = paint.cic_read(
|
||||||
|
mesh,
|
||||||
|
pos.reshape(-1, 3) +
|
||||||
|
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||||
|
|
||||||
|
return res.reshape(pos.shape[:-1])
|
||||||
|
|
||||||
|
return fn(mesh, pos)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(pjit,
|
||||||
|
in_axis_resources=PartitionSpec('nx', 'ny'),
|
||||||
|
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
|
||||||
|
def reshape_dense_to_split(x):
|
||||||
|
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
|
||||||
|
Changes the logical shape of the array, but no shuffling of the
|
||||||
|
data should be necessary
|
||||||
|
"""
|
||||||
|
shape = list(x.shape)
|
||||||
|
return x.reshape([
|
||||||
|
mesh_size['nx'], shape[0] //
|
||||||
|
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
|
||||||
|
] + shape[2:])
|
||||||
|
|
||||||
|
|
||||||
|
@partial(pjit,
|
||||||
|
in_axis_resources=PartitionSpec('nx', None, 'ny', None),
|
||||||
|
out_axis_resources=PartitionSpec('nx', 'ny'))
|
||||||
|
def reshape_split_to_dense(x):
|
||||||
|
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
|
||||||
|
Changes the logical shape of the array, but no shuffling of the
|
||||||
|
data should be necessary
|
||||||
|
"""
|
||||||
|
shape = list(x.shape)
|
||||||
|
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])
|
100
jaxpm/experimental/distributed_pm.py
Normal file
100
jaxpm/experimental/distributed_pm.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax_cosmo as jc
|
||||||
|
from jax.experimental.maps import xmap
|
||||||
|
from jax.lax import linear_solve_p
|
||||||
|
|
||||||
|
import jaxpm.experimental.distributed_ops as dops
|
||||||
|
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||||
|
from jaxpm.kernels import fftk
|
||||||
|
|
||||||
|
|
||||||
|
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
||||||
|
"""
|
||||||
|
Computes gravitational forces on particles using a PM scheme
|
||||||
|
"""
|
||||||
|
if mesh_shape is None:
|
||||||
|
mesh_shape = delta_k.shape
|
||||||
|
kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)]
|
||||||
|
|
||||||
|
if delta_k is None:
|
||||||
|
delta = dops.cic_paint(positions, mesh_shape, halo_size)
|
||||||
|
delta_k = dops.fft3d(dops.reshape_split_to_dense(delta))
|
||||||
|
|
||||||
|
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
|
||||||
|
|
||||||
|
# Recovers forces at particle positions
|
||||||
|
forces = [
|
||||||
|
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
|
||||||
|
halo_size) for f in forces_k
|
||||||
|
]
|
||||||
|
|
||||||
|
return dops.stack3d(*forces)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
|
||||||
|
"""
|
||||||
|
Generate initial conditions.
|
||||||
|
Seed should have the dimension of the computational mesh
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Sample normal field
|
||||||
|
field = dops.normal(seed, shape=mesh_shape)
|
||||||
|
|
||||||
|
# Go to Fourier space
|
||||||
|
field = dops.fft3d(dops.reshape_split_to_dense(field))
|
||||||
|
|
||||||
|
# Rescaling k to physical units
|
||||||
|
kvec = [
|
||||||
|
k.squeeze() / box_size[i] * mesh_shape[i]
|
||||||
|
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
|
||||||
|
]
|
||||||
|
k = jnp.logspace(-4, 2, 256)
|
||||||
|
pk = jc.power.linear_matter_power(cosmo, k)
|
||||||
|
pk = pk * (mesh_shape[0] * mesh_shape[1] *
|
||||||
|
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
|
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
|
||||||
|
|
||||||
|
if return_Fourier:
|
||||||
|
return field
|
||||||
|
else:
|
||||||
|
return dops.reshape_dense_to_split(dops.ifft3d(field))
|
||||||
|
|
||||||
|
|
||||||
|
def lpt(cosmo, initial_conditions, positions, a):
|
||||||
|
"""
|
||||||
|
Computes first order LPT displacement
|
||||||
|
"""
|
||||||
|
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
||||||
|
a = jnp.atleast_1d(a)
|
||||||
|
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
|
||||||
|
p = dops.scalar_multiply(
|
||||||
|
dx,
|
||||||
|
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||||
|
return dx, p
|
||||||
|
|
||||||
|
|
||||||
|
def make_ode_fn(mesh_shape):
|
||||||
|
|
||||||
|
def nbody_ode(state, a, cosmo):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
pos, vel = state
|
||||||
|
|
||||||
|
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
# Computes the update of position (drift)
|
||||||
|
dpos = dops.scalar_multiply(
|
||||||
|
vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = dops.scalar_multiply(
|
||||||
|
forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
|
||||||
|
|
||||||
|
return dpos, dvel
|
||||||
|
|
||||||
|
return nbody_ode
|
|
@ -1,8 +1,8 @@
|
||||||
import jax.numpy as np
|
import jax.numpy as np
|
||||||
|
from jax_cosmo.background import *
|
||||||
from jax_cosmo.scipy.interpolate import interp
|
from jax_cosmo.scipy.interpolate import interp
|
||||||
from jax_cosmo.scipy.ode import odeint
|
from jax_cosmo.scipy.ode import odeint
|
||||||
from jax_cosmo.background import *
|
|
||||||
|
|
||||||
def E(cosmo, a):
|
def E(cosmo, a):
|
||||||
r"""Scale factor dependent factor E(a) in the Hubble
|
r"""Scale factor dependent factor E(a) in the Hubble
|
||||||
|
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
|
||||||
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
|
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
|
||||||
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
|
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
|
||||||
"""
|
"""
|
||||||
return (
|
return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
|
||||||
3
|
np.power(np.log(a - epsilon), 2))
|
||||||
* cosmo.wa
|
|
||||||
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
|
|
||||||
/ np.power(np.log(a - epsilon), 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def dEa(cosmo, a):
|
def dEa(cosmo, a):
|
||||||
|
@ -89,15 +85,11 @@ def dEa(cosmo, a):
|
||||||
where :math:`f(a)` is the Dark Energy evolution parameter computed
|
where :math:`f(a)` is the Dark Energy evolution parameter computed
|
||||||
by :py:meth:`.f_de`.
|
by :py:meth:`.f_de`.
|
||||||
"""
|
"""
|
||||||
return (
|
return (0.5 *
|
||||||
0.5
|
(-3 * cosmo.Omega_m * np.power(a, -4) -
|
||||||
* (
|
2 * cosmo.Omega_k * np.power(a, -3) +
|
||||||
-3 * cosmo.Omega_m * np.power(a, -4)
|
df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
|
||||||
- 2 * cosmo.Omega_k * np.power(a, -3)
|
np.power(Esqr(cosmo, a), 0.5))
|
||||||
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
|
|
||||||
)
|
|
||||||
/ np.power(Esqr(cosmo, a), 0.5)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def growth_factor(cosmo, a):
|
def growth_factor(cosmo, a):
|
||||||
|
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
|
||||||
"""
|
"""
|
||||||
if cosmo._flags["gamma_growth"]:
|
if cosmo._flags["gamma_growth"]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Gamma growth rate is not implemented for second order growth!"
|
"Gamma growth rate is not implemented for second order growth!")
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return _growth_factor_second_ODE(cosmo, a)
|
return _growth_factor_second_ODE(cosmo, a)
|
||||||
|
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
|
||||||
"""
|
"""
|
||||||
if cosmo._flags["gamma_growth"]:
|
if cosmo._flags["gamma_growth"]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Gamma growth factor is not implemented for second order growth!"
|
"Gamma growth factor is not implemented for second order growth!")
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return _growth_rate_second_ODE(cosmo, a)
|
return _growth_rate_second_ODE(cosmo, a)
|
||||||
|
@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
|
||||||
atab = np.logspace(log10_amin, 0.0, steps)
|
atab = np.logspace(log10_amin, 0.0, steps)
|
||||||
|
|
||||||
def D_derivs(y, x):
|
def D_derivs(y, x):
|
||||||
q = (
|
q = (2.0 - 0.5 *
|
||||||
2.0
|
(Omega_m_a(cosmo, x) +
|
||||||
- 0.5
|
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
|
||||||
* (
|
|
||||||
Omega_m_a(cosmo, x)
|
|
||||||
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
|
|
||||||
)
|
|
||||||
) / x
|
|
||||||
r = 1.5 * Omega_m_a(cosmo, x) / x / x
|
r = 1.5 * Omega_m_a(cosmo, x) / x / x
|
||||||
|
|
||||||
g1, g2 = y[0]
|
g1, g2 = y[0]
|
||||||
f1, f2 = y[1]
|
f1, f2 = y[1]
|
||||||
dy1da = [f1, -q * f1 + r * g1]
|
dy1da = [f1, -q * f1 + r * g1]
|
||||||
dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2]
|
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
|
||||||
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
|
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
|
||||||
|
|
||||||
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]])
|
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
|
||||||
|
[1.0, -6.0 / 7 * atab[0]]])
|
||||||
y = odeint(D_derivs, y0, atab)
|
y = odeint(D_derivs, y0, atab)
|
||||||
|
|
||||||
# compute second order derivatives growth
|
# compute second order derivatives growth
|
||||||
|
@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
|
||||||
|
|
||||||
see :cite:`2019:Euclid Preparation VII, eqn.32`
|
see :cite:`2019:Euclid Preparation VII, eqn.32`
|
||||||
"""
|
"""
|
||||||
return Omega_m_a(cosmo, a) ** cosmo.gamma
|
return Omega_m_a(cosmo, a)**cosmo.gamma
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def Gf(cosmo, a):
|
def Gf(cosmo, a):
|
||||||
|
@ -503,7 +488,7 @@ def Gf(cosmo, a):
|
||||||
"""
|
"""
|
||||||
f1 = growth_rate(cosmo, a)
|
f1 = growth_rate(cosmo, a)
|
||||||
g1 = growth_factor(cosmo, a)
|
g1 = growth_factor(cosmo, a)
|
||||||
D1f = f1*g1/ a
|
D1f = f1 * g1 / a
|
||||||
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
|
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -532,7 +517,7 @@ def Gf2(cosmo, a):
|
||||||
"""
|
"""
|
||||||
f2 = growth_rate_second(cosmo, a)
|
f2 = growth_rate_second(cosmo, a)
|
||||||
g2 = growth_factor_second(cosmo, a)
|
g2 = growth_factor_second(cosmo, a)
|
||||||
D2f = f2*g2/ a
|
D2f = f2 * g2 / a
|
||||||
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
|
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -563,13 +548,12 @@ def dGfa(cosmo, a):
|
||||||
"""
|
"""
|
||||||
f1 = growth_rate(cosmo, a)
|
f1 = growth_rate(cosmo, a)
|
||||||
g1 = growth_factor(cosmo, a)
|
g1 = growth_factor(cosmo, a)
|
||||||
D1f = f1*g1/ a
|
D1f = f1 * g1 / a
|
||||||
cache = cosmo._workspace['background.growth_factor']
|
cache = cosmo._workspace['background.growth_factor']
|
||||||
f1p = cache['h'] / cache['a'] * cache['g']
|
f1p = cache['h'] / cache['a'] * cache['g']
|
||||||
f1p = interp(np.log(a), np.log(cache['a']), f1p)
|
f1p = interp(np.log(a), np.log(cache['a']), f1p)
|
||||||
Ea = E(cosmo, a)
|
Ea = E(cosmo, a)
|
||||||
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) +
|
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
|
||||||
3 * a**2 * Ea * D1f)
|
|
||||||
|
|
||||||
|
|
||||||
def dGf2a(cosmo, a):
|
def dGf2a(cosmo, a):
|
||||||
|
@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
|
||||||
"""
|
"""
|
||||||
f2 = growth_rate_second(cosmo, a)
|
f2 = growth_rate_second(cosmo, a)
|
||||||
g2 = growth_factor_second(cosmo, a)
|
g2 = growth_factor_second(cosmo, a)
|
||||||
D2f = f2*g2/ a
|
D2f = f2 * g2 / a
|
||||||
cache = cosmo._workspace['background.growth_factor']
|
cache = cosmo._workspace['background.growth_factor']
|
||||||
f2p = cache['h2'] / cache['a'] * cache['g2']
|
f2p = cache['h2'] / cache['a'] * cache['g2']
|
||||||
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
||||||
E = E(cosmo, a)
|
E = E(cosmo, a)
|
||||||
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) +
|
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
|
||||||
3 * a**2 * E * D2f)
|
|
||||||
|
|
116
jaxpm/kernels.py
116
jaxpm/kernels.py
|
@ -1,25 +1,27 @@
|
||||||
import numpy as np
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
||||||
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
||||||
"""
|
"""
|
||||||
k = []
|
k = []
|
||||||
for d in range(len(shape)):
|
for d in range(len(shape)):
|
||||||
kd = np.fft.fftfreq(shape[d])
|
kd = np.fft.fftfreq(shape[d])
|
||||||
kd *= 2 * np.pi
|
kd *= 2 * np.pi
|
||||||
kdshape = np.ones(len(shape), dtype='int')
|
kdshape = np.ones(len(shape), dtype='int')
|
||||||
if symmetric and d == len(shape) - 1:
|
if symmetric and d == len(shape) - 1:
|
||||||
kd = kd[:shape[d] // 2 + 1]
|
kd = kd[:shape[d] // 2 + 1]
|
||||||
kdshape[d] = len(kd)
|
kdshape[d] = len(kd)
|
||||||
kd = kd.reshape(kdshape)
|
kd = kd.reshape(kdshape)
|
||||||
|
|
||||||
|
k.append(kd.astype(dtype))
|
||||||
|
del kd, kdshape
|
||||||
|
return k
|
||||||
|
|
||||||
k.append(kd.astype(dtype))
|
|
||||||
del kd, kdshape
|
|
||||||
return k
|
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
"""
|
"""
|
||||||
Computes the gradient kernel in the requested direction
|
Computes the gradient kernel in the requested direction
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1):
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel
|
Complex kernel
|
||||||
"""
|
"""
|
||||||
if order == 0:
|
if order == 0:
|
||||||
wts = 1j * kvec[direction]
|
wts = 1j * kvec[direction]
|
||||||
wts = jnp.squeeze(wts)
|
wts = jnp.squeeze(wts)
|
||||||
wts[len(wts) // 2] = 0
|
wts[len(wts) // 2] = 0
|
||||||
wts = wts.reshape(kvec[direction].shape)
|
wts = wts.reshape(kvec[direction].shape)
|
||||||
return wts
|
return wts
|
||||||
else:
|
else:
|
||||||
w = kvec[direction]
|
w = kvec[direction]
|
||||||
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||||
wts = a * 1j
|
wts = a * 1j
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
||||||
def laplace_kernel(kvec):
|
def laplace_kernel(kvec):
|
||||||
"""
|
"""
|
||||||
Compute the Laplace kernel from a given K vector
|
Compute the Laplace kernel from a given K vector
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
@ -56,16 +59,17 @@ def laplace_kernel(kvec):
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel
|
Complex kernel
|
||||||
"""
|
"""
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
mask = (kk == 0).nonzero()
|
mask = (kk == 0).nonzero()
|
||||||
kk[mask] = 1
|
kk[mask] = 1
|
||||||
wts = 1. / kk
|
wts = 1. / kk
|
||||||
imask = (~(kk == 0)).astype(int)
|
imask = (~(kk == 0)).astype(int)
|
||||||
wts *= imask
|
wts *= imask
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
||||||
def longrange_kernel(kvec, r_split):
|
def longrange_kernel(kvec, r_split):
|
||||||
"""
|
"""
|
||||||
Computes a long range kernel
|
Computes a long range kernel
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
@ -78,29 +82,31 @@ def longrange_kernel(kvec, r_split):
|
||||||
wts: array
|
wts: array
|
||||||
kernel
|
kernel
|
||||||
"""
|
"""
|
||||||
if r_split != 0:
|
if r_split != 0:
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
return np.exp(-kk * r_split**2)
|
return np.exp(-kk * r_split**2)
|
||||||
else:
|
else:
|
||||||
return 1.
|
return 1.
|
||||||
|
|
||||||
|
|
||||||
def cic_compensation(kvec):
|
def cic_compensation(kvec):
|
||||||
"""
|
"""
|
||||||
Computes cic compensation kernel.
|
Computes cic compensation kernel.
|
||||||
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
|
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
|
||||||
Itself based on equation 18 (with p=2) of
|
Itself based on equation 18 (with p=2) of
|
||||||
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
|
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
|
||||||
Args:
|
Args:
|
||||||
kvec: array of k values in Fourier space
|
kvec: array of k values in Fourier space
|
||||||
Returns:
|
Returns:
|
||||||
v: array of kernel
|
v: array of kernel
|
||||||
"""
|
"""
|
||||||
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
||||||
def PGD_kernel(kvec, kl, ks):
|
def PGD_kernel(kvec, kl, ks):
|
||||||
"""
|
"""
|
||||||
Computes the PGD kernel
|
Computes the PGD kernel
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks):
|
||||||
v: array
|
v: array
|
||||||
kernel
|
kernel
|
||||||
"""
|
"""
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
kl2 = kl**2
|
kl2 = kl**2
|
||||||
ks4 = ks**4
|
ks4 = ks**4
|
||||||
mask = (kk == 0).nonzero()
|
mask = (kk == 0).nonzero()
|
||||||
kk[mask] = 1
|
kk[mask] = 1
|
||||||
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
|
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
|
||||||
imask = (~(kk == 0)).astype(int)
|
imask = (~(kk == 0)).astype(int)
|
||||||
v *= imask
|
v *= imask
|
||||||
return v
|
return v
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo.constants as constants
|
|
||||||
import jax_cosmo
|
import jax_cosmo
|
||||||
|
import jax_cosmo.constants as constants
|
||||||
from jax.scipy.ndimage import map_coordinates
|
from jax.scipy.ndimage import map_coordinates
|
||||||
from jaxpm.utils import gaussian_smoothing
|
|
||||||
from jaxpm.painting import cic_paint_2d
|
from jaxpm.painting import cic_paint_2d
|
||||||
|
from jaxpm.utils import gaussian_smoothing
|
||||||
|
|
||||||
|
|
||||||
def density_plane(positions,
|
def density_plane(positions,
|
||||||
box_shape,
|
box_shape,
|
||||||
|
@ -26,9 +27,11 @@ def density_plane(positions,
|
||||||
xy = xy / nx * plane_resolution
|
xy = xy / nx * plane_resolution
|
||||||
|
|
||||||
# Selecting only particles that fall inside the volume of interest
|
# Selecting only particles that fall inside the volume of interest
|
||||||
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
weight = jnp.where(
|
||||||
|
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
||||||
# Painting density plane
|
# Painting density plane
|
||||||
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
density_plane = cic_paint_2d(
|
||||||
|
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
||||||
|
|
||||||
# Apply density normalization
|
# Apply density normalization
|
||||||
density_plane = density_plane / ((nx / plane_resolution) *
|
density_plane = density_plane / ((nx / plane_resolution) *
|
||||||
|
@ -36,45 +39,44 @@ def density_plane(positions,
|
||||||
|
|
||||||
# Apply Gaussian smoothing if requested
|
# Apply Gaussian smoothing if requested
|
||||||
if smoothing_sigma is not None:
|
if smoothing_sigma is not None:
|
||||||
density_plane = gaussian_smoothing(density_plane,
|
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
|
||||||
smoothing_sigma)
|
|
||||||
|
|
||||||
return density_plane
|
return density_plane
|
||||||
|
|
||||||
|
|
||||||
def convergence_Born(cosmo,
|
def convergence_Born(cosmo, density_planes, coords, z_source):
|
||||||
density_planes,
|
"""
|
||||||
coords,
|
|
||||||
z_source):
|
|
||||||
"""
|
|
||||||
Compute the Born convergence
|
Compute the Born convergence
|
||||||
Args:
|
Args:
|
||||||
cosmo: `Cosmology`, cosmology object.
|
cosmo: `Cosmology`, cosmology object.
|
||||||
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
|
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
|
||||||
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
|
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
|
||||||
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
|
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
|
||||||
name: `string`, name of the operation.
|
name: `string`, name of the operation.
|
||||||
Returns:
|
Returns:
|
||||||
`Tensor` of shape [batch_size, N, Nz], of convergence values.
|
`Tensor` of shape [batch_size, N, Nz], of convergence values.
|
||||||
"""
|
"""
|
||||||
# Compute constant prefactor:
|
# Compute constant prefactor:
|
||||||
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||||
# Compute comoving distance of source galaxies
|
# Compute comoving distance of source galaxies
|
||||||
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
r_s = jax_cosmo.background.radial_comoving_distance(
|
||||||
|
cosmo, 1 / (1 + z_source))
|
||||||
|
|
||||||
convergence = 0
|
convergence = 0
|
||||||
for entry in density_planes:
|
for entry in density_planes:
|
||||||
r = entry['r']; a = entry['a']; p = entry['plane']
|
r = entry['r']
|
||||||
dx = entry['dx']; dz = entry['dz']
|
a = entry['a']
|
||||||
# Normalize density planes
|
p = entry['plane']
|
||||||
density_normalization = dz * r / a
|
dx = entry['dx']
|
||||||
p = (p - p.mean()) * constant_factor * density_normalization
|
dz = entry['dz']
|
||||||
|
# Normalize density planes
|
||||||
|
density_normalization = dz * r / a
|
||||||
|
p = (p - p.mean()) * constant_factor * density_normalization
|
||||||
|
|
||||||
# Interpolate at the density plane coordinates
|
# Interpolate at the density plane coordinates
|
||||||
im = map_coordinates(p,
|
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
|
||||||
coords * r / dx - 0.5,
|
|
||||||
order=1, mode="wrap")
|
|
||||||
|
|
||||||
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
|
convergence += im * jnp.clip(1. -
|
||||||
|
(r / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||||
|
|
||||||
return convergence
|
return convergence
|
||||||
|
|
60
jaxpm/nn.py
60
jaxpm/nn.py
|
@ -1,6 +1,7 @@
|
||||||
|
import haiku as hk
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import haiku as hk
|
|
||||||
|
|
||||||
def _deBoorVectorized(x, t, c, p):
|
def _deBoorVectorized(x, t, c, p):
|
||||||
"""
|
"""
|
||||||
|
@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
|
||||||
c: array of control points
|
c: array of control points
|
||||||
p: degree of B-spline
|
p: degree of B-spline
|
||||||
"""
|
"""
|
||||||
k = jnp.digitize(x, t) -1
|
k = jnp.digitize(x, t) - 1
|
||||||
|
|
||||||
d = [c[j + k - p] for j in range(0, p+1)]
|
d = [c[j + k - p] for j in range(0, p + 1)]
|
||||||
for r in range(1, p+1):
|
for r in range(1, p + 1):
|
||||||
for j in range(p, r-1, -1):
|
for j in range(p, r - 1, -1):
|
||||||
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
|
alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
|
||||||
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
|
d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
|
||||||
return d[p]
|
return d[p]
|
||||||
|
|
||||||
|
|
||||||
class NeuralSplineFourierFilter(hk.Module):
|
class NeuralSplineFourierFilter(hk.Module):
|
||||||
"""A rotationally invariant filter parameterized by
|
"""A rotationally invariant filter parameterized by
|
||||||
a b-spline with parameters specified by a small NN."""
|
a b-spline with parameters specified by a small NN."""
|
||||||
|
|
||||||
def __init__(self, n_knots=8, latent_size=16, name=None):
|
def __init__(self, n_knots=8, latent_size=16, name=None):
|
||||||
|
"""
|
||||||
|
n_knots: number of control points for the spline
|
||||||
"""
|
"""
|
||||||
n_knots: number of control points for the spline
|
super().__init__(name=name)
|
||||||
"""
|
self.n_knots = n_knots
|
||||||
super().__init__(name=name)
|
self.latent_size = latent_size
|
||||||
self.n_knots = n_knots
|
|
||||||
self.latent_size = latent_size
|
|
||||||
|
|
||||||
def __call__(self, x, a):
|
def __call__(self, x, a):
|
||||||
"""
|
"""
|
||||||
x: array, scale, normalized to fftfreq default
|
x: array, scale, normalized to fftfreq default
|
||||||
a: scalar, scale factor
|
a: scalar, scale factor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
|
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
|
||||||
net = jnp.sin(hk.Linear(self.latent_size)(net))
|
net = jnp.sin(hk.Linear(self.latent_size)(net))
|
||||||
|
|
||||||
w = hk.Linear(self.n_knots+1)(net)
|
w = hk.Linear(self.n_knots + 1)(net)
|
||||||
k = hk.Linear(self.n_knots-1)(net)
|
k = hk.Linear(self.n_knots - 1)(net)
|
||||||
|
|
||||||
# make sure the knots sum to 1 and are in the interval 0,1
|
|
||||||
k = jnp.concatenate([jnp.zeros((1,)),
|
|
||||||
jnp.cumsum(jax.nn.softmax(k))])
|
|
||||||
|
|
||||||
w = jnp.concatenate([jnp.zeros((1,)),
|
# make sure the knots sum to 1 and are in the interval 0,1
|
||||||
w])
|
k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
|
||||||
|
|
||||||
# Augment with repeating points
|
w = jnp.concatenate([jnp.zeros((1, )), w])
|
||||||
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
|
|
||||||
|
|
||||||
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
# Augment with repeating points
|
||||||
|
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
|
||||||
|
|
||||||
|
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
|
||||||
|
3)
|
||||||
|
|
|
@ -1,96 +1,100 @@
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from jaxpm.kernels import fftk, cic_compensation
|
from jaxpm.kernels import cic_compensation, fftk
|
||||||
|
|
||||||
def cic_paint(mesh, positions):
|
|
||||||
""" Paints positions onto mesh
|
def cic_paint(mesh, positions, weight=None):
|
||||||
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
if weight is not None:
|
||||||
|
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
|
neighboor_coords = jnp.mod(
|
||||||
|
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||||
|
jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1, 2),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1,
|
||||||
|
2))
|
||||||
|
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
|
||||||
|
dnums)
|
||||||
|
return mesh
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(
|
|
||||||
update_window_dims=(),
|
|
||||||
inserted_window_dims=(0, 1, 2),
|
|
||||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
|
||||||
mesh = lax.scatter_add(mesh,
|
|
||||||
neighboor_coords,
|
|
||||||
kernel.reshape([-1,8]),
|
|
||||||
dnums)
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
def cic_read(mesh, positions):
|
def cic_read(mesh, positions):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||||
|
jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||||
|
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||||
|
|
||||||
return (mesh[neighboor_coords[...,0],
|
|
||||||
neighboor_coords[...,1],
|
|
||||||
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
|
||||||
|
|
||||||
def cic_paint_2d(mesh, positions, weight):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
""" Paints positions onto a 2d mesh
|
""" Paints positions onto a 2d mesh
|
||||||
mesh: [nx, ny]
|
mesh: [nx, ny]
|
||||||
positions: [npart, 2]
|
positions: [npart, 2]
|
||||||
weight: [npart]
|
weight: [npart]
|
||||||
"""
|
"""
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1]
|
kernel = kernel[..., 0] * kernel[..., 1]
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
kernel = kernel * weight[...,jnp.newaxis]
|
kernel = kernel * weight[..., jnp.newaxis]
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
|
neighboor_coords = jnp.mod(
|
||||||
|
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||||
|
jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1),
|
||||||
|
scatter_dims_to_operand_dims=(0,
|
||||||
|
1))
|
||||||
|
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
|
||||||
|
dnums)
|
||||||
|
return mesh
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(
|
|
||||||
update_window_dims=(),
|
|
||||||
inserted_window_dims=(0, 1),
|
|
||||||
scatter_dims_to_operand_dims=(0, 1))
|
|
||||||
mesh = lax.scatter_add(mesh,
|
|
||||||
neighboor_coords,
|
|
||||||
kernel.reshape([-1,4]),
|
|
||||||
dnums)
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
def compensate_cic(field):
|
def compensate_cic(field):
|
||||||
"""
|
"""
|
||||||
Compensate for CiC painting
|
Compensate for CiC painting
|
||||||
Args:
|
Args:
|
||||||
field: input 3D cic-painted field
|
field: input 3D cic-painted field
|
||||||
Returns:
|
Returns:
|
||||||
compensated_field
|
compensated_field
|
||||||
"""
|
"""
|
||||||
nc = field.shape
|
nc = field.shape
|
||||||
kvec = fftk(nc)
|
kvec = fftk(nc)
|
||||||
|
|
||||||
delta_k = jnp.fft.rfftn(field)
|
delta_k = jnp.fft.rfftn(field)
|
||||||
delta_k = cic_compensation(kvec) * delta_k
|
delta_k = cic_compensation(kvec) * delta_k
|
||||||
return jnp.fft.irfftn(delta_k)
|
return jnp.fft.irfftn(delta_k)
|
||||||
|
|
38
jaxpm/pm.py
38
jaxpm/pm.py
|
@ -1,11 +1,12 @@
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
|
||||||
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel
|
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||||
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
||||||
|
longrange_kernel)
|
||||||
from jaxpm.painting import cic_paint, cic_read
|
from jaxpm.painting import cic_paint, cic_read
|
||||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
|
||||||
|
|
||||||
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
||||||
"""
|
"""
|
||||||
|
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
delta_k = jnp.fft.rfftn(delta)
|
||||||
|
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
|
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
||||||
|
r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions)
|
return jnp.stack([
|
||||||
for i in range(3)],axis=-1)
|
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
|
||||||
|
for i in range(3)
|
||||||
|
],
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def lpt(cosmo, initial_conditions, positions, a):
|
def lpt(cosmo, initial_conditions, positions, a):
|
||||||
|
@ -34,25 +39,31 @@ def lpt(cosmo, initial_conditions, positions, a):
|
||||||
initial_force = pm_forces(positions, delta=initial_conditions)
|
initial_force = pm_forces(positions, delta=initial_conditions)
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
dx = growth_factor(cosmo, a) * initial_force
|
dx = growth_factor(cosmo, a) * initial_force
|
||||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
|
||||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
|
a)) * dx
|
||||||
|
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
|
||||||
|
a) * initial_force
|
||||||
return dx, p, f
|
return dx, p, f
|
||||||
|
|
||||||
|
|
||||||
def linear_field(mesh_shape, box_size, pk, seed):
|
def linear_field(mesh_shape, box_size, pk, seed):
|
||||||
"""
|
"""
|
||||||
Generate initial conditions.
|
Generate initial conditions.
|
||||||
"""
|
"""
|
||||||
kvec = fftk(mesh_shape)
|
kvec = fftk(mesh_shape)
|
||||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
|
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
for i, kk in enumerate(kvec))**0.5
|
||||||
|
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||||
|
box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
field = jax.random.normal(seed, mesh_shape)
|
field = jax.random.normal(seed, mesh_shape)
|
||||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
||||||
field = jnp.fft.irfftn(field)
|
field = jnp.fft.irfftn(field)
|
||||||
return field
|
return field
|
||||||
|
|
||||||
|
|
||||||
def make_ode_fn(mesh_shape):
|
def make_ode_fn(mesh_shape):
|
||||||
|
|
||||||
def nbody_ode(state, a, cosmo):
|
def nbody_ode(state, a, cosmo):
|
||||||
"""
|
"""
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
|
@ -63,10 +74,10 @@ def make_ode_fn(mesh_shape):
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
@ -128,4 +139,3 @@ def make_neural_ode_fn(model, mesh_shape):
|
||||||
|
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
return neural_nbody_ode
|
return neural_nbody_ode
|
||||||
|
|
||||||
|
|
117
jaxpm/utils.py
117
jaxpm/utils.py
|
@ -1,85 +1,87 @@
|
||||||
import numpy as np
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
from jax.scipy.stats import norm
|
from jax.scipy.stats import norm
|
||||||
|
|
||||||
__all__ = ['power_spectrum']
|
__all__ = ['power_spectrum']
|
||||||
|
|
||||||
|
|
||||||
def _initialize_pk(shape, boxsize, kmin, dk):
|
def _initialize_pk(shape, boxsize, kmin, dk):
|
||||||
"""
|
"""
|
||||||
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
|
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
|
||||||
"""
|
"""
|
||||||
I = np.eye(len(shape), dtype='int') * -2 + 1
|
I = np.eye(len(shape), dtype='int') * -2 + 1
|
||||||
|
|
||||||
W = np.empty(shape, dtype='f4')
|
W = np.empty(shape, dtype='f4')
|
||||||
W[...] = 2.0
|
W[...] = 2.0
|
||||||
W[..., 0] = 1.0
|
W[..., 0] = 1.0
|
||||||
W[..., -1] = 1.0
|
W[..., -1] = 1.0
|
||||||
|
|
||||||
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
|
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
|
||||||
kedges = np.arange(kmin, kmax, dk)
|
kedges = np.arange(kmin, kmax, dk)
|
||||||
|
|
||||||
k = [
|
k = [
|
||||||
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
|
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
|
||||||
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
|
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
|
||||||
]
|
]
|
||||||
kmag = sum(ki**2 for ki in k)**0.5
|
kmag = sum(ki**2 for ki in k)**0.5
|
||||||
|
|
||||||
xsum = np.zeros(len(kedges) + 1)
|
xsum = np.zeros(len(kedges) + 1)
|
||||||
Nsum = np.zeros(len(kedges) + 1)
|
Nsum = np.zeros(len(kedges) + 1)
|
||||||
|
|
||||||
dig = np.digitize(kmag.flat, kedges)
|
dig = np.digitize(kmag.flat, kedges)
|
||||||
|
|
||||||
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
|
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
|
||||||
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
|
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
|
||||||
return dig, Nsum, xsum, W, k, kedges
|
return dig, Nsum, xsum, W, k, kedges
|
||||||
|
|
||||||
|
|
||||||
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
||||||
"""
|
"""
|
||||||
Calculate the powerspectra given real space field
|
Calculate the powerspectra given real space field
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
||||||
field: real valued field
|
field: real valued field
|
||||||
kmin: minimum k-value for binned powerspectra
|
kmin: minimum k-value for binned powerspectra
|
||||||
dk: differential in each kbin
|
dk: differential in each kbin
|
||||||
boxsize: length of each boxlength (can be strangly shaped?)
|
boxsize: length of each boxlength (can be strangly shaped?)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
kbins: the central value of the bins for plotting
|
kbins: the central value of the bins for plotting
|
||||||
power: real valued array of power in each bin
|
power: real valued array of power in each bin
|
||||||
|
|
||||||
"""
|
"""
|
||||||
shape = field.shape
|
shape = field.shape
|
||||||
nx, ny, nz = shape
|
nx, ny, nz = shape
|
||||||
|
|
||||||
#initialze values related to powerspectra (mode bins and weights)
|
#initialze values related to powerspectra (mode bins and weights)
|
||||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||||
|
|
||||||
#fast fourier transform
|
#fast fourier transform
|
||||||
fft_image = jnp.fft.fftn(field)
|
fft_image = jnp.fft.fftn(field)
|
||||||
|
|
||||||
#absolute value of fast fourier transform
|
#absolute value of fast fourier transform
|
||||||
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
||||||
|
|
||||||
|
#calculating powerspectra
|
||||||
|
real = jnp.real(pk).reshape([-1])
|
||||||
|
imag = jnp.imag(pk).reshape([-1])
|
||||||
|
|
||||||
#calculating powerspectra
|
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
|
||||||
real = jnp.real(pk).reshape([-1])
|
length=xsum.size) * 1j
|
||||||
imag = jnp.imag(pk).reshape([-1])
|
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||||
|
|
||||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
|
||||||
|
|
||||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
#normalization for powerspectra
|
||||||
|
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
||||||
|
|
||||||
#normalization for powerspectra
|
#find central values of each bin
|
||||||
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||||
|
|
||||||
#find central values of each bin
|
return kbins, P / norm
|
||||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
|
||||||
|
|
||||||
return kbins, P / norm
|
|
||||||
|
|
||||||
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
||||||
"""
|
"""
|
||||||
|
@ -131,18 +133,17 @@ def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=Fals
|
||||||
|
|
||||||
|
|
||||||
def gaussian_smoothing(im, sigma):
|
def gaussian_smoothing(im, sigma):
|
||||||
"""
|
"""
|
||||||
im: 2d image
|
im: 2d image
|
||||||
sigma: smoothing scale in px
|
sigma: smoothing scale in px
|
||||||
"""
|
"""
|
||||||
# Compute k vector
|
# Compute k vector
|
||||||
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
|
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
|
||||||
jnp.fft.fftfreq(im.shape[1])),
|
jnp.fft.fftfreq(im.shape[1])),
|
||||||
axis=-1)
|
axis=-1)
|
||||||
k = jnp.linalg.norm(kvec, axis=-1)
|
k = jnp.linalg.norm(kvec, axis=-1)
|
||||||
# We compute the value of the filter at frequency k
|
# We compute the value of the filter at frequency k
|
||||||
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
|
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
|
||||||
filter /= filter[0,0]
|
filter /= filter[0, 0]
|
||||||
|
|
||||||
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
|
|
||||||
|
|
||||||
|
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
|
||||||
|
|
6
setup.py
6
setup.py
|
@ -1,4 +1,4 @@
|
||||||
from setuptools import setup, find_packages
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='JaxPM',
|
name='JaxPM',
|
||||||
|
@ -6,6 +6,6 @@ setup(
|
||||||
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
|
url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
|
||||||
author='JaxPM developers',
|
author='JaxPM developers',
|
||||||
description='A dead simple FastPM implementation in JAX',
|
description='A dead simple FastPM implementation in JAX',
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=['jax', 'jax_cosmo'],
|
install_requires=['jax', 'jax_cosmo'],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue