forked from Aquila-Consortium/JaxPM_highres
Applying formatting
This commit is contained in:
parent
835fa89aec
commit
f28442bb48
14 changed files with 565 additions and 445 deletions
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,17 @@
|
|||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v2.3.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.40.2
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: ['--parallel', '--in-place']
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
name: isort (python)
|
|
@ -1,57 +1,80 @@
|
|||
# Can be executed with:
|
||||
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import jax.lax as lax
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.experimental.maps import Mesh, xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
from functools import partial
|
||||
|
||||
jax.distributed.initialize()
|
||||
|
||||
cube_size = 2048
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=[...],
|
||||
out_axes=['x','y', ...],
|
||||
axis_sizes={'x':cube_size, 'y':cube_size},
|
||||
axis_resources={'x': 'nx', 'y':'ny',
|
||||
'key_x':'nx', 'key_y':'ny'})
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_sizes={
|
||||
'x': cube_size,
|
||||
'y': cube_size
|
||||
},
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny',
|
||||
'key_x': 'nx',
|
||||
'key_y': 'ny'
|
||||
})
|
||||
def pnormal(key):
|
||||
return jax.random.normal(key, shape=[cube_size])
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes={0:'x', 1:'y'},
|
||||
out_axes=['x','y', ...],
|
||||
axis_resources={'x': 'nx', 'y': 'ny'})
|
||||
in_axes={
|
||||
0: 'x',
|
||||
1: 'y'
|
||||
},
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny'
|
||||
})
|
||||
@jax.jit
|
||||
def pfft3d(mesh):
|
||||
# [x, y, z]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on z
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on y
|
||||
mesh = jnp.fft.fft(mesh) # Transform on z
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on y
|
||||
# [z, x, y]
|
||||
return mesh
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes={0:'x', 1:'y'},
|
||||
out_axes=['x','y', ...],
|
||||
axis_resources={'x': 'nx', 'y': 'ny'})
|
||||
in_axes={
|
||||
0: 'x',
|
||||
1: 'y'
|
||||
},
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny'
|
||||
})
|
||||
@jax.jit
|
||||
def pifft3d(mesh):
|
||||
# [z, x, y]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on y
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on z
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on y
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on z
|
||||
# [x, y, z]
|
||||
return mesh
|
||||
|
||||
|
||||
key = jax.random.PRNGKey(42)
|
||||
# keys = jax.random.split(key, 4).reshape((2,2,2))
|
||||
|
||||
|
|
|
@ -1,44 +1,49 @@
|
|||
# Start this script with:
|
||||
# mpirun -np 4 python test_script.py
|
||||
import os
|
||||
|
||||
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
|
||||
import matplotlib.pylab as plt
|
||||
import jax
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
import tensorflow_probability as tfp
|
||||
from jax.experimental.maps import mesh, xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
|
||||
|
||||
tfp = tfp.substrates.jax
|
||||
tfd = tfp.distributions
|
||||
|
||||
|
||||
def cic_paint(mesh, positions):
|
||||
""" Paints positions onto mesh
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||
[0., 1, 1], [1., 1, 1]]])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1,
|
||||
2))
|
||||
mesh = lax.scatter_add(
|
||||
mesh,
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
kernel.reshape([-1, 8]), dnums)
|
||||
return mesh
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||
mesh = lax.scatter_add(mesh,
|
||||
neighboor_coords.reshape([-1,8,3]).astype('int32'),
|
||||
kernel.reshape([-1,8]),
|
||||
dnums)
|
||||
return mesh
|
||||
|
||||
# And let's draw some points from some 3D distribution
|
||||
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
|
||||
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
|
||||
scale_identity_multiplier=3.)
|
||||
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
|
||||
|
||||
f = pjit(lambda x: cic_paint(x, pos),
|
||||
|
@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
|
|||
m = jnp.zeros([32, 32, 32])
|
||||
|
||||
with mesh(devices, ('x', 'y', 'z')):
|
||||
# Shard the mesh, I'm not sure this is absolutely necessary
|
||||
m = pjit(lambda x: x,
|
||||
in_axis_resources=None,
|
||||
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
||||
# Shard the mesh, I'm not sure this is absolutely necessary
|
||||
m = pjit(lambda x: x,
|
||||
in_axis_resources=None,
|
||||
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
||||
|
||||
# Apply the sharded CiC function
|
||||
res = f(m)
|
||||
# Apply the sharded CiC function
|
||||
res = f(m)
|
||||
|
||||
plt.imshow(res.sum(axis=2))
|
||||
plt.show()
|
|
@ -1,11 +1,12 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.lax as lax
|
||||
from functools import partial
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import pjit, PartitionSpec
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
|
||||
import jaxpm.painting as paint
|
||||
|
||||
# TODO: add a way to configure axis resources from command line
|
||||
|
@ -14,35 +15,59 @@ mesh_size = {'nx': 2, 'ny': 2}
|
|||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'}),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, {
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, {
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def stack3d(a, b, c):
|
||||
return jnp.stack([a, b, c], axis=-1)
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},[...]),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, [...]),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def scalar_multiply(a, factor):
|
||||
return a * factor
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'}),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, {
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def add(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=['x', 'y',...],
|
||||
out_axes=['x', 'y',...],
|
||||
in_axes=['x', 'y', ...],
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources=axis_resources)
|
||||
def fft3d(mesh):
|
||||
""" Performs a 3D complex Fourier transform
|
||||
|
@ -62,8 +87,8 @@ def fft3d(mesh):
|
|||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=['x', 'y',...],
|
||||
out_axes=['x', 'y',...],
|
||||
in_axes=['x', 'y', ...],
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources=axis_resources)
|
||||
def ifft3d(mesh):
|
||||
mesh = jnp.fft.ifft(mesh)
|
||||
|
@ -72,10 +97,15 @@ def ifft3d(mesh):
|
|||
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
||||
return jnp.fft.ifft(mesh).real
|
||||
|
||||
|
||||
def normal(key, shape=[]):
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=['x', 'y',...],
|
||||
out_axes={0: 'x', 2: 'y'},
|
||||
in_axes=['x', 'y', ...],
|
||||
out_axes={
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
axis_resources=axis_resources)
|
||||
def fn(key):
|
||||
""" Generate a distributed random normal distributions
|
||||
|
@ -83,99 +113,126 @@ def normal(key, shape=[]):
|
|||
key: array of random keys with same layout as computational mesh
|
||||
shape: logical shape of array to sample
|
||||
"""
|
||||
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
|
||||
shape[1]//mesh_size['ny']]+shape[2:])
|
||||
return jax.random.normal(
|
||||
key,
|
||||
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
|
||||
shape[2:])
|
||||
|
||||
return fn(key)
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=(['x', 'y', ...],
|
||||
[['x'], ['y'], [...]], [...], [...]),
|
||||
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources=axis_resources)
|
||||
@jax.jit
|
||||
def scale_by_power_spectrum(kfield, kvec, k, pk):
|
||||
kx, ky, kz = kvec
|
||||
kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
|
||||
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
|
||||
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=(['x', 'y', 'z'],
|
||||
[['x'], ['y'], ['z']]),
|
||||
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
|
||||
out_axes=(['x', 'y', 'z']),
|
||||
axis_resources=axis_resources)
|
||||
def gradient_laplace_kernel(kfield, kvec):
|
||||
kx, ky, kz = kvec
|
||||
kk = (kx**2 + ky**2 + kz**2)
|
||||
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
|
||||
kfield * kernel * 1j * 1 / 6.0 *
|
||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
|
||||
kernel = jnp.where(kk == 0, 1., 1. / kk)
|
||||
return (kfield * kernel * 1j * 1 / 6.0 *
|
||||
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
|
||||
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
|
||||
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=([...]),
|
||||
out_axes={0: 'x', 2: 'y'},
|
||||
axis_sizes={'x': mesh_size['nx'],
|
||||
'y': mesh_size['ny']},
|
||||
out_axes={
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
axis_sizes={
|
||||
'x': mesh_size['nx'],
|
||||
'y': mesh_size['ny']
|
||||
},
|
||||
axis_resources=axis_resources)
|
||||
def meshgrid(x, y, z):
|
||||
""" Generates a mesh grid of appropriate size for the
|
||||
computational mesh we have.
|
||||
"""
|
||||
return jnp.stack(jnp.meshgrid(x,
|
||||
y,
|
||||
z), axis=-1)
|
||||
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
|
||||
|
||||
|
||||
def cic_paint(pos, mesh_shape, halo_size=0):
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'}),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def fn(pos):
|
||||
|
||||
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
|
||||
mesh_shape[1]//mesh_size['ny']+2*halo_size]
|
||||
+ mesh_shape[2:])
|
||||
mesh = jnp.zeros([
|
||||
mesh_shape[0] // mesh_size['nx'] +
|
||||
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
|
||||
] + mesh_shape[2:])
|
||||
|
||||
# Paint particles
|
||||
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
mesh = paint.cic_paint(
|
||||
mesh,
|
||||
pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
|
||||
# Perform halo exchange
|
||||
# Halo exchange along x
|
||||
left = lax.pshuffle(mesh[-2*halo_size:],
|
||||
left = lax.pshuffle(mesh[-2 * halo_size:],
|
||||
perm=range(mesh_size['nx'])[::-1],
|
||||
axis_name='x')
|
||||
right = lax.pshuffle(mesh[:2*halo_size],
|
||||
right = lax.pshuffle(mesh[:2 * halo_size],
|
||||
perm=range(mesh_size['nx'])[::-1],
|
||||
axis_name='x')
|
||||
mesh = mesh.at[:2*halo_size].add(left)
|
||||
mesh = mesh.at[-2*halo_size:].add(right)
|
||||
mesh = mesh.at[:2 * halo_size].add(left)
|
||||
mesh = mesh.at[-2 * halo_size:].add(right)
|
||||
|
||||
# Halo exchange along y
|
||||
left = lax.pshuffle(mesh[:, -2*halo_size:],
|
||||
left = lax.pshuffle(mesh[:, -2 * halo_size:],
|
||||
perm=range(mesh_size['ny'])[::-1],
|
||||
axis_name='y')
|
||||
right = lax.pshuffle(mesh[:, :2*halo_size],
|
||||
right = lax.pshuffle(mesh[:, :2 * halo_size],
|
||||
perm=range(mesh_size['ny'])[::-1],
|
||||
axis_name='y')
|
||||
mesh = mesh.at[:, :2*halo_size].add(left)
|
||||
mesh = mesh.at[:, -2*halo_size:].add(right)
|
||||
mesh = mesh.at[:, :2 * halo_size].add(left)
|
||||
mesh = mesh.at[:, -2 * halo_size:].add(right)
|
||||
|
||||
# removing halo and returning mesh
|
||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||
|
||||
return fn(pos)
|
||||
|
||||
|
||||
def cic_read(mesh, pos, halo_size=0):
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'},),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=(
|
||||
{
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
{
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def fn(mesh, pos):
|
||||
|
||||
|
@ -198,8 +255,10 @@ def cic_read(mesh, pos, halo_size=0):
|
|||
mesh = jnp.concatenate([left, mesh, right], axis=1)
|
||||
|
||||
# Reading field at particles positions
|
||||
res = paint.cic_read(mesh, pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
res = paint.cic_read(
|
||||
mesh,
|
||||
pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
|
||||
return res.reshape(pos.shape[:-1])
|
||||
|
||||
|
@ -215,8 +274,10 @@ def reshape_dense_to_split(x):
|
|||
data should be necessary
|
||||
"""
|
||||
shape = list(x.shape)
|
||||
return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
|
||||
mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:])
|
||||
return x.reshape([
|
||||
mesh_size['nx'], shape[0] //
|
||||
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
|
||||
] + shape[2:])
|
||||
|
||||
|
||||
@partial(pjit,
|
||||
|
@ -228,4 +289,4 @@ def reshape_split_to_dense(x):
|
|||
data should be necessary
|
||||
"""
|
||||
shape = list(x.shape)
|
||||
return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
|
||||
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import jax
|
||||
from jax.lax import linear_solve_p
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.maps import xmap
|
||||
from functools import partial
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.kernels import fftk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.lax import linear_solve_p
|
||||
|
||||
import jaxpm.experimental.distributed_ops as dops
|
||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||
from jaxpm.kernels import fftk
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
||||
|
@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
|||
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
|
||||
|
||||
# Recovers forces at particle positions
|
||||
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
|
||||
positions, halo_size) for f in forces_k]
|
||||
forces = [
|
||||
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
|
||||
halo_size) for f in forces_k
|
||||
]
|
||||
|
||||
return dops.stack3d(*forces)
|
||||
|
||||
|
@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
|
|||
field = dops.fft3d(dops.reshape_split_to_dense(field))
|
||||
|
||||
# Rescaling k to physical units
|
||||
kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
|
||||
kvec = [
|
||||
k.squeeze() / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
|
||||
]
|
||||
k = jnp.logspace(-4, 2, 256)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
|
||||
) / (box_size[0] * box_size[1] * box_size[2])
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] *
|
||||
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
|
||||
|
||||
|
@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
|
|||
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
|
||||
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
||||
jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||
p = dops.scalar_multiply(
|
||||
dx,
|
||||
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||
return dx, p
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import jax.numpy as np
|
||||
|
||||
from jax_cosmo.background import *
|
||||
from jax_cosmo.scipy.interpolate import interp
|
||||
from jax_cosmo.scipy.ode import odeint
|
||||
from jax_cosmo.background import *
|
||||
|
||||
|
||||
def E(cosmo, a):
|
||||
r"""Scale factor dependent factor E(a) in the Hubble
|
||||
|
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
|
|||
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
|
||||
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
|
||||
"""
|
||||
return (
|
||||
3
|
||||
* cosmo.wa
|
||||
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
|
||||
/ np.power(np.log(a - epsilon), 2)
|
||||
)
|
||||
return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
|
||||
np.power(np.log(a - epsilon), 2))
|
||||
|
||||
|
||||
def dEa(cosmo, a):
|
||||
|
@ -89,15 +85,11 @@ def dEa(cosmo, a):
|
|||
where :math:`f(a)` is the Dark Energy evolution parameter computed
|
||||
by :py:meth:`.f_de`.
|
||||
"""
|
||||
return (
|
||||
0.5
|
||||
* (
|
||||
-3 * cosmo.Omega_m * np.power(a, -4)
|
||||
- 2 * cosmo.Omega_k * np.power(a, -3)
|
||||
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
|
||||
)
|
||||
/ np.power(Esqr(cosmo, a), 0.5)
|
||||
)
|
||||
return (0.5 *
|
||||
(-3 * cosmo.Omega_m * np.power(a, -4) -
|
||||
2 * cosmo.Omega_k * np.power(a, -3) +
|
||||
df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
|
||||
np.power(Esqr(cosmo, a), 0.5))
|
||||
|
||||
|
||||
def growth_factor(cosmo, a):
|
||||
|
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
|
|||
"""
|
||||
if cosmo._flags["gamma_growth"]:
|
||||
raise NotImplementedError(
|
||||
"Gamma growth rate is not implemented for second order growth!"
|
||||
)
|
||||
"Gamma growth rate is not implemented for second order growth!")
|
||||
return None
|
||||
else:
|
||||
return _growth_factor_second_ODE(cosmo, a)
|
||||
|
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
|
|||
"""
|
||||
if cosmo._flags["gamma_growth"]:
|
||||
raise NotImplementedError(
|
||||
"Gamma growth factor is not implemented for second order growth!"
|
||||
)
|
||||
"Gamma growth factor is not implemented for second order growth!")
|
||||
return None
|
||||
else:
|
||||
return _growth_rate_second_ODE(cosmo, a)
|
||||
|
@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
|
|||
atab = np.logspace(log10_amin, 0.0, steps)
|
||||
|
||||
def D_derivs(y, x):
|
||||
q = (
|
||||
2.0
|
||||
- 0.5
|
||||
* (
|
||||
Omega_m_a(cosmo, x)
|
||||
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
|
||||
)
|
||||
) / x
|
||||
q = (2.0 - 0.5 *
|
||||
(Omega_m_a(cosmo, x) +
|
||||
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
|
||||
r = 1.5 * Omega_m_a(cosmo, x) / x / x
|
||||
|
||||
g1, g2 = y[0]
|
||||
f1, f2 = y[1]
|
||||
dy1da = [f1, -q * f1 + r * g1]
|
||||
dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2]
|
||||
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
|
||||
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
|
||||
|
||||
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]])
|
||||
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
|
||||
[1.0, -6.0 / 7 * atab[0]]])
|
||||
y = odeint(D_derivs, y0, atab)
|
||||
|
||||
# compute second order derivatives growth
|
||||
|
@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
|
|||
|
||||
see :cite:`2019:Euclid Preparation VII, eqn.32`
|
||||
"""
|
||||
return Omega_m_a(cosmo, a) ** cosmo.gamma
|
||||
|
||||
return Omega_m_a(cosmo, a)**cosmo.gamma
|
||||
|
||||
|
||||
def Gf(cosmo, a):
|
||||
|
@ -503,7 +488,7 @@ def Gf(cosmo, a):
|
|||
"""
|
||||
f1 = growth_rate(cosmo, a)
|
||||
g1 = growth_factor(cosmo, a)
|
||||
D1f = f1*g1/ a
|
||||
D1f = f1 * g1 / a
|
||||
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
|
||||
|
||||
|
||||
|
@ -532,7 +517,7 @@ def Gf2(cosmo, a):
|
|||
"""
|
||||
f2 = growth_rate_second(cosmo, a)
|
||||
g2 = growth_factor_second(cosmo, a)
|
||||
D2f = f2*g2/ a
|
||||
D2f = f2 * g2 / a
|
||||
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
|
||||
|
||||
|
||||
|
@ -563,13 +548,12 @@ def dGfa(cosmo, a):
|
|||
"""
|
||||
f1 = growth_rate(cosmo, a)
|
||||
g1 = growth_factor(cosmo, a)
|
||||
D1f = f1*g1/ a
|
||||
D1f = f1 * g1 / a
|
||||
cache = cosmo._workspace['background.growth_factor']
|
||||
f1p = cache['h'] / cache['a'] * cache['g']
|
||||
f1p = interp(np.log(a), np.log(cache['a']), f1p)
|
||||
Ea = E(cosmo, a)
|
||||
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) +
|
||||
3 * a**2 * Ea * D1f)
|
||||
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
|
||||
|
||||
|
||||
def dGf2a(cosmo, a):
|
||||
|
@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
|
|||
"""
|
||||
f2 = growth_rate_second(cosmo, a)
|
||||
g2 = growth_factor_second(cosmo, a)
|
||||
D2f = f2*g2/ a
|
||||
D2f = f2 * g2 / a
|
||||
cache = cosmo._workspace['background.growth_factor']
|
||||
f2p = cache['h2'] / cache['a'] * cache['g2']
|
||||
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
||||
E = E(cosmo, a)
|
||||
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) +
|
||||
3 * a**2 * E * D2f)
|
||||
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
|
||||
|
|
114
jaxpm/kernels.py
114
jaxpm/kernels.py
|
@ -1,25 +1,27 @@
|
|||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
||||
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
||||
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
||||
"""
|
||||
k = []
|
||||
for d in range(len(shape)):
|
||||
kd = np.fft.fftfreq(shape[d])
|
||||
kd *= 2 * np.pi
|
||||
kdshape = np.ones(len(shape), dtype='int')
|
||||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
kdshape[d] = len(kd)
|
||||
kd = kd.reshape(kdshape)
|
||||
k = []
|
||||
for d in range(len(shape)):
|
||||
kd = np.fft.fftfreq(shape[d])
|
||||
kd *= 2 * np.pi
|
||||
kdshape = np.ones(len(shape), dtype='int')
|
||||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
kdshape[d] = len(kd)
|
||||
kd = kd.reshape(kdshape)
|
||||
|
||||
k.append(kd.astype(dtype))
|
||||
del kd, kdshape
|
||||
return k
|
||||
|
||||
k.append(kd.astype(dtype))
|
||||
del kd, kdshape
|
||||
return k
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
"""
|
||||
"""
|
||||
Computes the gradient kernel in the requested direction
|
||||
Parameters:
|
||||
-----------
|
||||
|
@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1):
|
|||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
if order == 0:
|
||||
wts = 1j * kvec[direction]
|
||||
wts = jnp.squeeze(wts)
|
||||
wts[len(wts) // 2] = 0
|
||||
wts = wts.reshape(kvec[direction].shape)
|
||||
return wts
|
||||
else:
|
||||
w = kvec[direction]
|
||||
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||
wts = a * 1j
|
||||
return wts
|
||||
if order == 0:
|
||||
wts = 1j * kvec[direction]
|
||||
wts = jnp.squeeze(wts)
|
||||
wts[len(wts) // 2] = 0
|
||||
wts = wts.reshape(kvec[direction].shape)
|
||||
return wts
|
||||
else:
|
||||
w = kvec[direction]
|
||||
a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
|
||||
wts = a * 1j
|
||||
return wts
|
||||
|
||||
|
||||
def laplace_kernel(kvec):
|
||||
"""
|
||||
"""
|
||||
Compute the Laplace kernel from a given K vector
|
||||
Parameters:
|
||||
-----------
|
||||
|
@ -56,16 +59,17 @@ def laplace_kernel(kvec):
|
|||
wts: array
|
||||
Complex kernel
|
||||
"""
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
mask = (kk == 0).nonzero()
|
||||
kk[mask] = 1
|
||||
wts = 1. / kk
|
||||
imask = (~(kk == 0)).astype(int)
|
||||
wts *= imask
|
||||
return wts
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
mask = (kk == 0).nonzero()
|
||||
kk[mask] = 1
|
||||
wts = 1. / kk
|
||||
imask = (~(kk == 0)).astype(int)
|
||||
wts *= imask
|
||||
return wts
|
||||
|
||||
|
||||
def longrange_kernel(kvec, r_split):
|
||||
"""
|
||||
"""
|
||||
Computes a long range kernel
|
||||
Parameters:
|
||||
-----------
|
||||
|
@ -78,14 +82,15 @@ def longrange_kernel(kvec, r_split):
|
|||
wts: array
|
||||
kernel
|
||||
"""
|
||||
if r_split != 0:
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
return np.exp(-kk * r_split**2)
|
||||
else:
|
||||
return 1.
|
||||
if r_split != 0:
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
return np.exp(-kk * r_split**2)
|
||||
else:
|
||||
return 1.
|
||||
|
||||
|
||||
def cic_compensation(kvec):
|
||||
"""
|
||||
"""
|
||||
Computes cic compensation kernel.
|
||||
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
|
||||
Itself based on equation 18 (with p=2) of
|
||||
|
@ -95,12 +100,13 @@ def cic_compensation(kvec):
|
|||
Returns:
|
||||
v: array of kernel
|
||||
"""
|
||||
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||
return wts
|
||||
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
|
||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||
return wts
|
||||
|
||||
|
||||
def PGD_kernel(kvec, kl, ks):
|
||||
"""
|
||||
"""
|
||||
Computes the PGD kernel
|
||||
Parameters:
|
||||
-----------
|
||||
|
@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks):
|
|||
v: array
|
||||
kernel
|
||||
"""
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
kl2 = kl**2
|
||||
ks4 = ks**4
|
||||
mask = (kk == 0).nonzero()
|
||||
kk[mask] = 1
|
||||
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
|
||||
imask = (~(kk == 0)).astype(int)
|
||||
v *= imask
|
||||
return v
|
||||
kk = sum(ki**2 for ki in kvec)
|
||||
kl2 = kl**2
|
||||
ks4 = ks**4
|
||||
mask = (kk == 0).nonzero()
|
||||
kk[mask] = 1
|
||||
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
|
||||
imask = (~(kk == 0)).astype(int)
|
||||
v *= imask
|
||||
return v
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo.constants as constants
|
||||
import jax_cosmo
|
||||
|
||||
import jax_cosmo.constants as constants
|
||||
from jax.scipy.ndimage import map_coordinates
|
||||
from jaxpm.utils import gaussian_smoothing
|
||||
|
||||
from jaxpm.painting import cic_paint_2d
|
||||
from jaxpm.utils import gaussian_smoothing
|
||||
|
||||
|
||||
def density_plane(positions,
|
||||
box_shape,
|
||||
|
@ -26,9 +27,11 @@ def density_plane(positions,
|
|||
xy = xy / nx * plane_resolution
|
||||
|
||||
# Selecting only particles that fall inside the volume of interest
|
||||
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
||||
weight = jnp.where(
|
||||
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
||||
# Painting density plane
|
||||
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
||||
density_plane = cic_paint_2d(
|
||||
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
||||
|
||||
# Apply density normalization
|
||||
density_plane = density_plane / ((nx / plane_resolution) *
|
||||
|
@ -36,17 +39,13 @@ def density_plane(positions,
|
|||
|
||||
# Apply Gaussian smoothing if requested
|
||||
if smoothing_sigma is not None:
|
||||
density_plane = gaussian_smoothing(density_plane,
|
||||
smoothing_sigma)
|
||||
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
|
||||
|
||||
return density_plane
|
||||
|
||||
|
||||
def convergence_Born(cosmo,
|
||||
density_planes,
|
||||
coords,
|
||||
z_source):
|
||||
"""
|
||||
def convergence_Born(cosmo, density_planes, coords, z_source):
|
||||
"""
|
||||
Compute the Born convergence
|
||||
Args:
|
||||
cosmo: `Cosmology`, cosmology object.
|
||||
|
@ -57,24 +56,27 @@ def convergence_Born(cosmo,
|
|||
Returns:
|
||||
`Tensor` of shape [batch_size, N, Nz], of convergence values.
|
||||
"""
|
||||
# Compute constant prefactor:
|
||||
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||
# Compute comoving distance of source galaxies
|
||||
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
||||
# Compute constant prefactor:
|
||||
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||
# Compute comoving distance of source galaxies
|
||||
r_s = jax_cosmo.background.radial_comoving_distance(
|
||||
cosmo, 1 / (1 + z_source))
|
||||
|
||||
convergence = 0
|
||||
for entry in density_planes:
|
||||
r = entry['r']; a = entry['a']; p = entry['plane']
|
||||
dx = entry['dx']; dz = entry['dz']
|
||||
# Normalize density planes
|
||||
density_normalization = dz * r / a
|
||||
p = (p - p.mean()) * constant_factor * density_normalization
|
||||
convergence = 0
|
||||
for entry in density_planes:
|
||||
r = entry['r']
|
||||
a = entry['a']
|
||||
p = entry['plane']
|
||||
dx = entry['dx']
|
||||
dz = entry['dz']
|
||||
# Normalize density planes
|
||||
density_normalization = dz * r / a
|
||||
p = (p - p.mean()) * constant_factor * density_normalization
|
||||
|
||||
# Interpolate at the density plane coordinates
|
||||
im = map_coordinates(p,
|
||||
coords * r / dx - 0.5,
|
||||
order=1, mode="wrap")
|
||||
# Interpolate at the density plane coordinates
|
||||
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
|
||||
|
||||
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||
convergence += im * jnp.clip(1. -
|
||||
(r / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||
|
||||
return convergence
|
||||
return convergence
|
||||
|
|
54
jaxpm/nn.py
54
jaxpm/nn.py
|
@ -1,6 +1,7 @@
|
|||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import haiku as hk
|
||||
|
||||
|
||||
def _deBoorVectorized(x, t, c, p):
|
||||
"""
|
||||
|
@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
|
|||
c: array of control points
|
||||
p: degree of B-spline
|
||||
"""
|
||||
k = jnp.digitize(x, t) -1
|
||||
k = jnp.digitize(x, t) - 1
|
||||
|
||||
d = [c[j + k - p] for j in range(0, p+1)]
|
||||
for r in range(1, p+1):
|
||||
for j in range(p, r-1, -1):
|
||||
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
|
||||
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
|
||||
d = [c[j + k - p] for j in range(0, p + 1)]
|
||||
for r in range(1, p + 1):
|
||||
for j in range(p, r - 1, -1):
|
||||
alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
|
||||
d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
|
||||
return d[p]
|
||||
|
||||
|
||||
class NeuralSplineFourierFilter(hk.Module):
|
||||
"""A rotationally invariant filter parameterized by
|
||||
"""A rotationally invariant filter parameterized by
|
||||
a b-spline with parameters specified by a small NN."""
|
||||
|
||||
def __init__(self, n_knots=8, latent_size=16, name=None):
|
||||
"""
|
||||
def __init__(self, n_knots=8, latent_size=16, name=None):
|
||||
"""
|
||||
n_knots: number of control points for the spline
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
self.n_knots = n_knots
|
||||
self.latent_size = latent_size
|
||||
super().__init__(name=name)
|
||||
self.n_knots = n_knots
|
||||
self.latent_size = latent_size
|
||||
|
||||
def __call__(self, x, a):
|
||||
"""
|
||||
def __call__(self, x, a):
|
||||
"""
|
||||
x: array, scale, normalized to fftfreq default
|
||||
a: scalar, scale factor
|
||||
"""
|
||||
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(net))
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
|
||||
net = jnp.sin(hk.Linear(self.latent_size)(net))
|
||||
|
||||
w = hk.Linear(self.n_knots+1)(net)
|
||||
k = hk.Linear(self.n_knots-1)(net)
|
||||
w = hk.Linear(self.n_knots + 1)(net)
|
||||
k = hk.Linear(self.n_knots - 1)(net)
|
||||
|
||||
# make sure the knots sum to 1 and are in the interval 0,1
|
||||
k = jnp.concatenate([jnp.zeros((1,)),
|
||||
jnp.cumsum(jax.nn.softmax(k))])
|
||||
# make sure the knots sum to 1 and are in the interval 0,1
|
||||
k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
|
||||
|
||||
w = jnp.concatenate([jnp.zeros((1,)),
|
||||
w])
|
||||
w = jnp.concatenate([jnp.zeros((1, )), w])
|
||||
|
||||
# Augment with repeating points
|
||||
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
|
||||
# Augment with repeating points
|
||||
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
|
||||
|
||||
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
|
||||
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
|
||||
3)
|
||||
|
|
|
@ -1,98 +1,100 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
|
||||
from jaxpm.kernels import fftk, cic_compensation
|
||||
|
||||
def cic_paint(mesh, positions, weight=None):
|
||||
""" Paints positions onto mesh
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||
[0., 1, 1], [1., 1, 1]]])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
jnp.array(mesh.shape))
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1,
|
||||
2))
|
||||
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
|
||||
dnums)
|
||||
return mesh
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||
mesh = lax.scatter_add(mesh,
|
||||
neighboor_coords,
|
||||
kernel.reshape([-1,8]),
|
||||
dnums)
|
||||
return mesh
|
||||
|
||||
def cic_read(mesh, positions):
|
||||
""" Paints positions onto mesh
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||
[0., 1, 1], [1., 1, 1]]])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||
jnp.array(mesh.shape))
|
||||
|
||||
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||
|
||||
return (mesh[neighboor_coords[...,0],
|
||||
neighboor_coords[...,1],
|
||||
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
||||
|
||||
def cic_paint_2d(mesh, positions, weight):
|
||||
""" Paints positions onto a 2d mesh
|
||||
""" Paints positions onto a 2d mesh
|
||||
mesh: [nx, ny]
|
||||
positions: [npart, 2]
|
||||
weight: [npart]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[...,jnp.newaxis]
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[..., jnp.newaxis]
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||
jnp.array(mesh.shape))
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1),
|
||||
scatter_dims_to_operand_dims=(0,
|
||||
1))
|
||||
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
|
||||
dnums)
|
||||
return mesh
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(),
|
||||
inserted_window_dims=(0, 1),
|
||||
scatter_dims_to_operand_dims=(0, 1))
|
||||
mesh = lax.scatter_add(mesh,
|
||||
neighboor_coords,
|
||||
kernel.reshape([-1,4]),
|
||||
dnums)
|
||||
return mesh
|
||||
|
||||
def compensate_cic(field):
|
||||
"""
|
||||
"""
|
||||
Compensate for CiC painting
|
||||
Args:
|
||||
field: input 3D cic-painted field
|
||||
Returns:
|
||||
compensated_field
|
||||
"""
|
||||
nc = field.shape
|
||||
kvec = fftk(nc)
|
||||
nc = field.shape
|
||||
kvec = fftk(nc)
|
||||
|
||||
delta_k = jnp.fft.rfftn(field)
|
||||
delta_k = cic_compensation(kvec) * delta_k
|
||||
return jnp.fft.irfftn(delta_k)
|
||||
delta_k = jnp.fft.rfftn(field)
|
||||
delta_k = cic_compensation(kvec) * delta_k
|
||||
return jnp.fft.irfftn(delta_k)
|
||||
|
|
44
jaxpm/pm.py
44
jaxpm/pm.py
|
@ -1,11 +1,12 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel
|
||||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
||||
longrange_kernel)
|
||||
from jaxpm.painting import cic_paint, cic_read
|
||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
||||
"""
|
||||
|
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
|||
delta_k = jnp.fft.rfftn(delta)
|
||||
|
||||
# Computes gravitational potential
|
||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
|
||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
||||
r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions)
|
||||
for i in range(3)],axis=-1)
|
||||
return jnp.stack([
|
||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
|
||||
def lpt(cosmo, initial_conditions, positions, a):
|
||||
|
@ -34,23 +39,29 @@ def lpt(cosmo, initial_conditions, positions, a):
|
|||
initial_force = pm_forces(positions, delta=initial_conditions)
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
|
||||
a)) * dx
|
||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
|
||||
a) * initial_force
|
||||
return dx, p, f
|
||||
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed):
|
||||
"""
|
||||
Generate initial conditions.
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||
for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||
box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
field = jax.random.normal(seed, mesh_shape)
|
||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
||||
field = jnp.fft.irfftn(field)
|
||||
return field
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
|
@ -84,13 +95,16 @@ def pgd_correction(pos, params):
|
|||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||
alpha, kl, ks = params
|
||||
delta_k = jnp.fft.rfftn(delta)
|
||||
PGD_range=PGD_kernel(kvec, kl, ks)
|
||||
PGD_range = PGD_kernel(kvec, kl, ks)
|
||||
|
||||
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
|
||||
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
|
||||
|
||||
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
||||
for i in range(3)],axis=-1)
|
||||
forces_pgd = jnp.stack([
|
||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
dpos_pgd = forces_pgd*alpha
|
||||
dpos_pgd = forces_pgd * alpha
|
||||
|
||||
return dpos_pgd
|
103
jaxpm/utils.py
103
jaxpm/utils.py
|
@ -1,41 +1,42 @@
|
|||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.scipy.stats import norm
|
||||
|
||||
__all__ = ['power_spectrum']
|
||||
|
||||
|
||||
def _initialize_pk(shape, boxsize, kmin, dk):
|
||||
"""
|
||||
"""
|
||||
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
|
||||
"""
|
||||
I = np.eye(len(shape), dtype='int') * -2 + 1
|
||||
I = np.eye(len(shape), dtype='int') * -2 + 1
|
||||
|
||||
W = np.empty(shape, dtype='f4')
|
||||
W[...] = 2.0
|
||||
W[..., 0] = 1.0
|
||||
W[..., -1] = 1.0
|
||||
W = np.empty(shape, dtype='f4')
|
||||
W[...] = 2.0
|
||||
W[..., 0] = 1.0
|
||||
W[..., -1] = 1.0
|
||||
|
||||
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
|
||||
kedges = np.arange(kmin, kmax, dk)
|
||||
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
|
||||
kedges = np.arange(kmin, kmax, dk)
|
||||
|
||||
k = [
|
||||
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
|
||||
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
|
||||
]
|
||||
kmag = sum(ki**2 for ki in k)**0.5
|
||||
k = [
|
||||
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
|
||||
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
|
||||
]
|
||||
kmag = sum(ki**2 for ki in k)**0.5
|
||||
|
||||
xsum = np.zeros(len(kedges) + 1)
|
||||
Nsum = np.zeros(len(kedges) + 1)
|
||||
xsum = np.zeros(len(kedges) + 1)
|
||||
Nsum = np.zeros(len(kedges) + 1)
|
||||
|
||||
dig = np.digitize(kmag.flat, kedges)
|
||||
dig = np.digitize(kmag.flat, kedges)
|
||||
|
||||
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
|
||||
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
|
||||
return dig, Nsum, xsum, W, k, kedges
|
||||
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
|
||||
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
|
||||
return dig, Nsum, xsum, W, k, kedges
|
||||
|
||||
|
||||
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
||||
"""
|
||||
"""
|
||||
Calculate the powerspectra given real space field
|
||||
|
||||
Args:
|
||||
|
@ -51,49 +52,49 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
|||
power: real valued array of power in each bin
|
||||
|
||||
"""
|
||||
shape = field.shape
|
||||
nx, ny, nz = shape
|
||||
shape = field.shape
|
||||
nx, ny, nz = shape
|
||||
|
||||
#initialze values related to powerspectra (mode bins and weights)
|
||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||
#initialze values related to powerspectra (mode bins and weights)
|
||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||
|
||||
#fast fourier transform
|
||||
fft_image = jnp.fft.fftn(field)
|
||||
#fast fourier transform
|
||||
fft_image = jnp.fft.fftn(field)
|
||||
|
||||
#absolute value of fast fourier transform
|
||||
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
||||
#absolute value of fast fourier transform
|
||||
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
||||
|
||||
#calculating powerspectra
|
||||
real = jnp.real(pk).reshape([-1])
|
||||
imag = jnp.imag(pk).reshape([-1])
|
||||
|
||||
#calculating powerspectra
|
||||
real = jnp.real(pk).reshape([-1])
|
||||
imag = jnp.imag(pk).reshape([-1])
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
|
||||
length=xsum.size) * 1j
|
||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||
|
||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||
#normalization for powerspectra
|
||||
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
||||
|
||||
#normalization for powerspectra
|
||||
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
||||
#find central values of each bin
|
||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||
|
||||
#find central values of each bin
|
||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||
return kbins, P / norm
|
||||
|
||||
return kbins, P / norm
|
||||
|
||||
def gaussian_smoothing(im, sigma):
|
||||
"""
|
||||
"""
|
||||
im: 2d image
|
||||
sigma: smoothing scale in px
|
||||
"""
|
||||
# Compute k vector
|
||||
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
|
||||
jnp.fft.fftfreq(im.shape[1])),
|
||||
axis=-1)
|
||||
k = jnp.linalg.norm(kvec, axis=-1)
|
||||
# We compute the value of the filter at frequency k
|
||||
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
|
||||
filter /= filter[0,0]
|
||||
|
||||
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
|
||||
# Compute k vector
|
||||
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
|
||||
jnp.fft.fftfreq(im.shape[1])),
|
||||
axis=-1)
|
||||
k = jnp.linalg.norm(kvec, axis=-1)
|
||||
# We compute the value of the filter at frequency k
|
||||
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
|
||||
filter /= filter[0, 0]
|
||||
|
||||
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
|
||||
|
|
2
setup.py
2
setup.py
|
@ -1,4 +1,4 @@
|
|||
from setuptools import setup, find_packages
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name='JaxPM',
|
||||
|
|
Loading…
Add table
Reference in a new issue