Applying formatting

This commit is contained in:
EiffL 2024-07-09 14:54:34 -04:00
parent 835fa89aec
commit f28442bb48
14 changed files with 565 additions and 445 deletions

17
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)

View file

@ -1,31 +1,46 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import jax.lax as lax
from jax.experimental.maps import xmap
from jax.experimental.maps import Mesh
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
from functools import partial
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x','y', ...],
axis_sizes={'x':cube_size, 'y':cube_size},
axis_resources={'x': 'nx', 'y':'ny',
'key_x':'nx', 'key_y':'ny'})
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y', ...],
axis_resources={'x': 'nx', 'y': 'ny'})
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
@ -37,10 +52,17 @@ def pfft3d(mesh):
# [z, x, y]
return mesh
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y', ...],
axis_resources={'x': 'nx', 'y': 'ny'})
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
@ -52,6 +74,7 @@ def pifft3d(mesh):
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))

View file

@ -1,17 +1,21 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import matplotlib.pylab as plt
import jax
import numpy as np
import jax.numpy as jnp
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfp = tfp.substrates.jax
tfd = tfp.distributions
def cic_paint(mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
@ -19,26 +23,27 @@ def cic_paint(mesh, positions):
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords.reshape([-1,8,3]).astype('int32'),
kernel.reshape([-1,8]),
dnums)
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1, 8]), dnums)
return mesh
# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos),

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
from functools import partial
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit, PartitionSpec
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
@ -14,35 +15,59 @@ mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'},
{0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1)
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},[...]),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
return a * factor
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def add(a, b):
return a + b
@partial(xmap,
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
@ -62,8 +87,8 @@ def fft3d(mesh):
@partial(xmap,
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
@ -72,10 +97,15 @@ def ifft3d(mesh):
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y',...],
out_axes={0: 'x', 2: 'y'},
in_axes=['x', 'y', ...],
out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources)
def fn(key):
""" Generate a distributed random normal distributions
@ -83,99 +113,126 @@ def normal(key, shape=[]):
key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample
"""
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
shape[1]//mesh_size['ny']]+shape[2:])
return jax.random.normal(
key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...],
[['x'], ['y'], [...]], [...], [...]),
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap,
in_axes=(['x', 'y', 'z'],
[['x'], ['y'], ['z']]),
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap,
in_axes=([...]),
out_axes={0: 'x', 2: 'y'},
axis_sizes={'x': mesh_size['nx'],
'y': mesh_size['ny']},
out_axes={
0: 'x',
2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources)
def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x,
y,
z), axis=-1)
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(pos):
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
mesh_shape[1]//mesh_size['ny']+2*halo_size]
+ mesh_shape[2:])
mesh = jnp.zeros([
mesh_shape[0] // mesh_size['nx'] +
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
mesh = paint.cic_paint(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
# Halo exchange along x
left = lax.pshuffle(mesh[-2*halo_size:],
left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:2*halo_size],
right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = mesh.at[:2*halo_size].add(left)
mesh = mesh.at[-2*halo_size:].add(right)
mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y
left = lax.pshuffle(mesh[:, -2*halo_size:],
left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :2*halo_size],
right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = mesh.at[:, :2*halo_size].add(left)
mesh = mesh.at[:, -2*halo_size:].add(right)
mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'},),
out_axes=({0: 'x', 2: 'y'}),
in_axes=(
{
0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(mesh, pos):
@ -198,7 +255,9 @@ def cic_read(mesh, pos, halo_size=0):
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = paint.cic_read(mesh, pos.reshape(-1, 3) +
res = paint.cic_read(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1])
@ -215,8 +274,10 @@ def reshape_dense_to_split(x):
data should be necessary
"""
shape = list(x.shape)
return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:])
return x.reshape([
mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit,
@ -228,4 +289,4 @@ def reshape_split_to_dense(x):
data should be necessary
"""
shape = list(x.shape)
return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -1,13 +1,14 @@
import jax
from jax.lax import linear_solve_p
import jax.numpy as jnp
from jax.experimental.maps import xmap
from functools import partial
import jax_cosmo as jc
from jaxpm.kernels import fftk
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import growth_factor, growth_rate, dGfa
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
positions, halo_size) for f in forces_k]
forces = [
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces)
@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units
kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
kvec = [
k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
) / (box_size[0] * box_size[1] * box_size[2])
pk = pk * (mesh_shape[0] * mesh_shape[1] *
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
jnp.sqrt(jc.background.Esqr(cosmo, a)))
p = dops.scalar_multiply(
dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p

View file

@ -1,8 +1,8 @@
import jax.numpy as np
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint
from jax_cosmo.background import *
def E(cosmo, a):
r"""Scale factor dependent factor E(a) in the Hubble
@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
\frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
\frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
"""
return (
3
* cosmo.wa
* (np.log(a - epsilon) - (a - 1) / (a - epsilon))
/ np.power(np.log(a - epsilon), 2)
)
return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
np.power(np.log(a - epsilon), 2))
def dEa(cosmo, a):
@ -89,15 +85,11 @@ def dEa(cosmo, a):
where :math:`f(a)` is the Dark Energy evolution parameter computed
by :py:meth:`.f_de`.
"""
return (
0.5
* (
-3 * cosmo.Omega_m * np.power(a, -4)
- 2 * cosmo.Omega_k * np.power(a, -3)
+ df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
)
/ np.power(Esqr(cosmo, a), 0.5)
)
return (0.5 *
(-3 * cosmo.Omega_m * np.power(a, -4) -
2 * cosmo.Omega_k * np.power(a, -3) +
df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
np.power(Esqr(cosmo, a), 0.5))
def growth_factor(cosmo, a):
@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth rate is not implemented for second order growth!"
)
"Gamma growth rate is not implemented for second order growth!")
return None
else:
return _growth_factor_second_ODE(cosmo, a)
@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
"""
if cosmo._flags["gamma_growth"]:
raise NotImplementedError(
"Gamma growth factor is not implemented for second order growth!"
)
"Gamma growth factor is not implemented for second order growth!")
return None
else:
return _growth_rate_second_ODE(cosmo, a)
@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
atab = np.logspace(log10_amin, 0.0, steps)
def D_derivs(y, x):
q = (
2.0
- 0.5
* (
Omega_m_a(cosmo, x)
+ (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
)
) / x
q = (2.0 - 0.5 *
(Omega_m_a(cosmo, x) +
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
r = 1.5 * Omega_m_a(cosmo, x) / x / x
g1, g2 = y[0]
f1, f2 = y[1]
dy1da = [f1, -q * f1 + r * g1]
dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2]
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]])
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
[1.0, -6.0 / 7 * atab[0]]])
y = odeint(D_derivs, y0, atab)
# compute second order derivatives growth
@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
see :cite:`2019:Euclid Preparation VII, eqn.32`
"""
return Omega_m_a(cosmo, a) ** cosmo.gamma
return Omega_m_a(cosmo, a)**cosmo.gamma
def Gf(cosmo, a):
@ -503,7 +488,7 @@ def Gf(cosmo, a):
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a
D1f = f1 * g1 / a
return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -532,7 +517,7 @@ def Gf2(cosmo, a):
"""
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a
D2f = f2 * g2 / a
return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
@ -563,13 +548,12 @@ def dGfa(cosmo, a):
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1*g1/ a
D1f = f1 * g1 / a
cache = cosmo._workspace['background.growth_factor']
f1p = cache['h'] / cache['a'] * cache['g']
f1p = interp(np.log(a), np.log(cache['a']), f1p)
Ea = E(cosmo, a)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) +
3 * a**2 * Ea * D1f)
return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
def dGf2a(cosmo, a):
@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
"""
f2 = growth_rate_second(cosmo, a)
g2 = growth_factor_second(cosmo, a)
D2f = f2*g2/ a
D2f = f2 * g2 / a
cache = cosmo._workspace['background.growth_factor']
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E * D2f)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)

View file

@ -1,5 +1,6 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size
@ -18,6 +19,7 @@ def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
del kd, kdshape
return k
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction
@ -44,6 +46,7 @@ def gradient_kernel(kvec, direction, order=1):
wts = a * 1j
return wts
def laplace_kernel(kvec):
"""
Compute the Laplace kernel from a given K vector
@ -64,6 +67,7 @@ def laplace_kernel(kvec):
wts *= imask
return wts
def longrange_kernel(kvec, r_split):
"""
Computes a long range kernel
@ -84,6 +88,7 @@ def longrange_kernel(kvec, r_split):
else:
return 1.
def cic_compensation(kvec):
"""
Computes cic compensation kernel.
@ -99,6 +104,7 @@ def cic_compensation(kvec):
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
def PGD_kernel(kvec, kl, ks):
"""
Computes the PGD kernel

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d
from jaxpm.utils import gaussian_smoothing
def density_plane(positions,
box_shape,
@ -26,9 +27,11 @@ def density_plane(positions,
xy = xy / nx * plane_resolution
# Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
@ -36,16 +39,12 @@ def density_plane(positions,
# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane,
smoothing_sigma)
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
return density_plane
def convergence_Born(cosmo,
density_planes,
coords,
z_source):
def convergence_Born(cosmo, density_planes, coords, z_source):
"""
Compute the Born convergence
Args:
@ -60,21 +59,24 @@ def convergence_Born(cosmo,
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))
convergence = 0
for entry in density_planes:
r = entry['r']; a = entry['a']; p = entry['plane']
dx = entry['dx']; dz = entry['dz']
r = entry['r']
a = entry['a']
p = entry['plane']
dx = entry['dx']
dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates
im = map_coordinates(p,
coords * r / dx - 0.5,
order=1, mode="wrap")
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence

View file

@ -1,6 +1,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
import haiku as hk
def _deBoorVectorized(x, t, c, p):
"""
@ -13,13 +14,13 @@ def _deBoorVectorized(x, t, c, p):
c: array of control points
p: degree of B-spline
"""
k = jnp.digitize(x, t) -1
k = jnp.digitize(x, t) - 1
d = [c[j + k - p] for j in range(0, p+1)]
for r in range(1, p+1):
for j in range(p, r-1, -1):
alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
d = [c[j + k - p] for j in range(0, p + 1)]
for r in range(1, p + 1):
for j in range(p, r - 1, -1):
alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
return d[p]
@ -44,17 +45,16 @@ class NeuralSplineFourierFilter(hk.Module):
net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
net = jnp.sin(hk.Linear(self.latent_size)(net))
w = hk.Linear(self.n_knots+1)(net)
k = hk.Linear(self.n_knots-1)(net)
w = hk.Linear(self.n_knots + 1)(net)
k = hk.Linear(self.n_knots - 1)(net)
# make sure the knots sum to 1 and are in the interval 0,1
k = jnp.concatenate([jnp.zeros((1,)),
jnp.cumsum(jax.nn.softmax(k))])
k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
w = jnp.concatenate([jnp.zeros((1,)),
w])
w = jnp.concatenate([jnp.zeros((1, )), w])
# Augment with repeating points
ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
3)

View file

@ -1,8 +1,9 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, weight=None):
""" Paints positions onto mesh
@ -11,9 +12,8 @@ def cic_paint(mesh, positions, weight=None):
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
@ -21,18 +21,19 @@ def cic_paint(mesh, positions, weight=None):
if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,8]),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
return mesh
def cic_read(mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
@ -40,19 +41,19 @@ def cic_read(mesh, positions):
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
return (mesh[neighboor_coords[...,0],
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
@ -68,20 +69,21 @@ def cic_paint_2d(mesh, positions, weight):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[...,jnp.newaxis]
kernel = kernel * weight[..., jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,4]),
scatter_dims_to_operand_dims=(0,
1))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
dnums)
return mesh
def compensate_cic(field):
"""
Compensate for CiC painting

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel)
from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
"""
@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
# Computes gravitational forces
return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions)
for i in range(3)],axis=-1)
return jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
],
axis=-1)
def lpt(cosmo, initial_conditions, positions, a):
@ -34,23 +39,29 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta=initial_conditions)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
a) * initial_force
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
return field
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
@ -84,13 +95,16 @@ def pgd_correction(pos, params):
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range=PGD_kernel(kvec, kl, ks)
PGD_range = PGD_kernel(kvec, kl, ks)
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
for i in range(3)],axis=-1)
forces_pgd = jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
],
axis=-1)
dpos_pgd = forces_pgd*alpha
dpos_pgd = forces_pgd * alpha
return dpos_pgd

View file

@ -1,9 +1,10 @@
import numpy as np
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
__all__ = ['power_spectrum']
def _initialize_pk(shape, boxsize, kmin, dk):
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
@ -63,12 +64,12 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
@ -81,6 +82,7 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
return kbins, P / norm
def gaussian_smoothing(im, sigma):
"""
im: 2d image
@ -93,7 +95,6 @@ def gaussian_smoothing(im, sigma):
k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0,0]
filter /= filter[0, 0]
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real

View file

@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import find_packages, setup
setup(
name='JaxPM',