forked from guilhem_lavaux/JaxPM
Applying formatting
This commit is contained in:
parent
835fa89aec
commit
f28442bb48
14 changed files with 565 additions and 445 deletions
|
@ -1,11 +1,12 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.lax as lax
|
||||
from functools import partial
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import pjit, PartitionSpec
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
|
||||
import jaxpm.painting as paint
|
||||
|
||||
# TODO: add a way to configure axis resources from command line
|
||||
|
@ -14,35 +15,59 @@ mesh_size = {'nx': 2, 'ny': 2}
|
|||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'}),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, {
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, {
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def stack3d(a, b, c):
|
||||
return jnp.stack([a, b, c], axis=-1)
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},[...]),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, [...]),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def scalar_multiply(a, factor):
|
||||
return a * factor
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'}),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}, {
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def add(a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=['x', 'y',...],
|
||||
out_axes=['x', 'y',...],
|
||||
in_axes=['x', 'y', ...],
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources=axis_resources)
|
||||
def fft3d(mesh):
|
||||
""" Performs a 3D complex Fourier transform
|
||||
|
@ -51,7 +76,7 @@ def fft3d(mesh):
|
|||
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
|
||||
|
||||
Returns:
|
||||
3D FFT of the input, note that the dimensions of the output
|
||||
3D FFT of the input, note that the dimensions of the output
|
||||
are tranposed.
|
||||
"""
|
||||
mesh = jnp.fft.fft(mesh)
|
||||
|
@ -62,8 +87,8 @@ def fft3d(mesh):
|
|||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=['x', 'y',...],
|
||||
out_axes=['x', 'y',...],
|
||||
in_axes=['x', 'y', ...],
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources=axis_resources)
|
||||
def ifft3d(mesh):
|
||||
mesh = jnp.fft.ifft(mesh)
|
||||
|
@ -72,10 +97,15 @@ def ifft3d(mesh):
|
|||
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
||||
return jnp.fft.ifft(mesh).real
|
||||
|
||||
|
||||
def normal(key, shape=[]):
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=['x', 'y',...],
|
||||
out_axes={0: 'x', 2: 'y'},
|
||||
in_axes=['x', 'y', ...],
|
||||
out_axes={
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
axis_resources=axis_resources)
|
||||
def fn(key):
|
||||
""" Generate a distributed random normal distributions
|
||||
|
@ -83,99 +113,126 @@ def normal(key, shape=[]):
|
|||
key: array of random keys with same layout as computational mesh
|
||||
shape: logical shape of array to sample
|
||||
"""
|
||||
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
|
||||
shape[1]//mesh_size['ny']]+shape[2:])
|
||||
return jax.random.normal(
|
||||
key,
|
||||
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
|
||||
shape[2:])
|
||||
|
||||
return fn(key)
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=(['x', 'y', ...],
|
||||
[['x'], ['y'], [...]], [...], [...]),
|
||||
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources=axis_resources)
|
||||
@jax.jit
|
||||
def scale_by_power_spectrum(kfield, kvec, k, pk):
|
||||
kx, ky, kz = kvec
|
||||
kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
|
||||
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
|
||||
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=(['x', 'y', 'z'],
|
||||
[['x'], ['y'], ['z']]),
|
||||
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
|
||||
out_axes=(['x', 'y', 'z']),
|
||||
axis_resources=axis_resources)
|
||||
def gradient_laplace_kernel(kfield, kvec):
|
||||
kx, ky, kz = kvec
|
||||
kk = (kx**2 + ky**2 + kz**2)
|
||||
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
|
||||
kfield * kernel * 1j * 1 / 6.0 *
|
||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
|
||||
kernel = jnp.where(kk == 0, 1., 1. / kk)
|
||||
return (kfield * kernel * 1j * 1 / 6.0 *
|
||||
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
|
||||
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
|
||||
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=([...]),
|
||||
out_axes={0: 'x', 2: 'y'},
|
||||
axis_sizes={'x': mesh_size['nx'],
|
||||
'y': mesh_size['ny']},
|
||||
out_axes={
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
axis_sizes={
|
||||
'x': mesh_size['nx'],
|
||||
'y': mesh_size['ny']
|
||||
},
|
||||
axis_resources=axis_resources)
|
||||
def meshgrid(x, y, z):
|
||||
""" Generates a mesh grid of appropriate size for the
|
||||
""" Generates a mesh grid of appropriate size for the
|
||||
computational mesh we have.
|
||||
"""
|
||||
return jnp.stack(jnp.meshgrid(x,
|
||||
y,
|
||||
z), axis=-1)
|
||||
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
|
||||
|
||||
|
||||
def cic_paint(pos, mesh_shape, halo_size=0):
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'}),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def fn(pos):
|
||||
|
||||
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
|
||||
mesh_shape[1]//mesh_size['ny']+2*halo_size]
|
||||
+ mesh_shape[2:])
|
||||
mesh = jnp.zeros([
|
||||
mesh_shape[0] // mesh_size['nx'] +
|
||||
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
|
||||
] + mesh_shape[2:])
|
||||
|
||||
# Paint particles
|
||||
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
mesh = paint.cic_paint(
|
||||
mesh,
|
||||
pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
|
||||
# Perform halo exchange
|
||||
# Halo exchange along x
|
||||
left = lax.pshuffle(mesh[-2*halo_size:],
|
||||
left = lax.pshuffle(mesh[-2 * halo_size:],
|
||||
perm=range(mesh_size['nx'])[::-1],
|
||||
axis_name='x')
|
||||
right = lax.pshuffle(mesh[:2*halo_size],
|
||||
right = lax.pshuffle(mesh[:2 * halo_size],
|
||||
perm=range(mesh_size['nx'])[::-1],
|
||||
axis_name='x')
|
||||
mesh = mesh.at[:2*halo_size].add(left)
|
||||
mesh = mesh.at[-2*halo_size:].add(right)
|
||||
mesh = mesh.at[:2 * halo_size].add(left)
|
||||
mesh = mesh.at[-2 * halo_size:].add(right)
|
||||
|
||||
# Halo exchange along y
|
||||
left = lax.pshuffle(mesh[:, -2*halo_size:],
|
||||
left = lax.pshuffle(mesh[:, -2 * halo_size:],
|
||||
perm=range(mesh_size['ny'])[::-1],
|
||||
axis_name='y')
|
||||
right = lax.pshuffle(mesh[:, :2*halo_size],
|
||||
right = lax.pshuffle(mesh[:, :2 * halo_size],
|
||||
perm=range(mesh_size['ny'])[::-1],
|
||||
axis_name='y')
|
||||
mesh = mesh.at[:, :2*halo_size].add(left)
|
||||
mesh = mesh.at[:, -2*halo_size:].add(right)
|
||||
mesh = mesh.at[:, :2 * halo_size].add(left)
|
||||
mesh = mesh.at[:, -2 * halo_size:].add(right)
|
||||
|
||||
# removing halo and returning mesh
|
||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||
|
||||
return fn(pos)
|
||||
|
||||
|
||||
def cic_read(mesh, pos, halo_size=0):
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=({0: 'x', 2: 'y'},
|
||||
{0: 'x', 2: 'y'},),
|
||||
out_axes=({0: 'x', 2: 'y'}),
|
||||
in_axes=(
|
||||
{
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
{
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
},
|
||||
),
|
||||
out_axes=({
|
||||
0: 'x',
|
||||
2: 'y'
|
||||
}),
|
||||
axis_resources=axis_resources)
|
||||
def fn(mesh, pos):
|
||||
|
||||
|
@ -198,11 +255,13 @@ def cic_read(mesh, pos, halo_size=0):
|
|||
mesh = jnp.concatenate([left, mesh, right], axis=1)
|
||||
|
||||
# Reading field at particles positions
|
||||
res = paint.cic_read(mesh, pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
res = paint.cic_read(
|
||||
mesh,
|
||||
pos.reshape(-1, 3) +
|
||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||
|
||||
return res.reshape(pos.shape[:-1])
|
||||
|
||||
|
||||
return fn(mesh, pos)
|
||||
|
||||
|
||||
|
@ -211,12 +270,14 @@ def cic_read(mesh, pos, halo_size=0):
|
|||
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
|
||||
def reshape_dense_to_split(x):
|
||||
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
|
||||
Changes the logical shape of the array, but no shuffling of the
|
||||
Changes the logical shape of the array, but no shuffling of the
|
||||
data should be necessary
|
||||
"""
|
||||
shape = list(x.shape)
|
||||
return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
|
||||
mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:])
|
||||
return x.reshape([
|
||||
mesh_size['nx'], shape[0] //
|
||||
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
|
||||
] + shape[2:])
|
||||
|
||||
|
||||
@partial(pjit,
|
||||
|
@ -224,8 +285,8 @@ def reshape_dense_to_split(x):
|
|||
out_axis_resources=PartitionSpec('nx', 'ny'))
|
||||
def reshape_split_to_dense(x):
|
||||
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
|
||||
Changes the logical shape of the array, but no shuffling of the
|
||||
Changes the logical shape of the array, but no shuffling of the
|
||||
data should be necessary
|
||||
"""
|
||||
shape = list(x.shape)
|
||||
return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
|
||||
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import jax
|
||||
from jax.lax import linear_solve_p
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental.maps import xmap
|
||||
from functools import partial
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.kernels import fftk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.lax import linear_solve_p
|
||||
|
||||
import jaxpm.experimental.distributed_ops as dops
|
||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||
from jaxpm.kernels import fftk
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
||||
|
@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
|||
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
|
||||
|
||||
# Recovers forces at particle positions
|
||||
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
|
||||
positions, halo_size) for f in forces_k]
|
||||
forces = [
|
||||
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
|
||||
halo_size) for f in forces_k
|
||||
]
|
||||
|
||||
return dops.stack3d(*forces)
|
||||
|
||||
|
@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
|
|||
field = dops.fft3d(dops.reshape_split_to_dense(field))
|
||||
|
||||
# Rescaling k to physical units
|
||||
kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
|
||||
kvec = [
|
||||
k.squeeze() / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
|
||||
]
|
||||
k = jnp.logspace(-4, 2, 256)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
|
||||
) / (box_size[0] * box_size[1] * box_size[2])
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] *
|
||||
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
|
||||
|
||||
|
@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
|
|||
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
|
||||
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
||||
jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||
p = dops.scalar_multiply(
|
||||
dx,
|
||||
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||
return dx, p
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue