Applying formatting

This commit is contained in:
EiffL 2024-07-09 14:54:34 -04:00
parent 835fa89aec
commit f28442bb48
14 changed files with 565 additions and 445 deletions

View file

@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax.lax as lax
from functools import partial
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit, PartitionSpec
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
@ -14,35 +15,59 @@ mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'},
{0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1)
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},[...]),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
return a * factor
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def add(a, b):
return a + b
@partial(xmap,
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
@ -51,7 +76,7 @@ def fft3d(mesh):
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
Returns:
3D FFT of the input, note that the dimensions of the output
3D FFT of the input, note that the dimensions of the output
are tranposed.
"""
mesh = jnp.fft.fft(mesh)
@ -62,8 +87,8 @@ def fft3d(mesh):
@partial(xmap,
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
@ -72,10 +97,15 @@ def ifft3d(mesh):
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y',...],
out_axes={0: 'x', 2: 'y'},
in_axes=['x', 'y', ...],
out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources)
def fn(key):
""" Generate a distributed random normal distributions
@ -83,99 +113,126 @@ def normal(key, shape=[]):
key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample
"""
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
shape[1]//mesh_size['ny']]+shape[2:])
return jax.random.normal(
key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...],
[['x'], ['y'], [...]], [...], [...]),
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap,
in_axes=(['x', 'y', 'z'],
[['x'], ['y'], ['z']]),
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap,
in_axes=([...]),
out_axes={0: 'x', 2: 'y'},
axis_sizes={'x': mesh_size['nx'],
'y': mesh_size['ny']},
out_axes={
0: 'x',
2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources)
def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x,
y,
z), axis=-1)
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
in_axes=({
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(pos):
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
mesh_shape[1]//mesh_size['ny']+2*halo_size]
+ mesh_shape[2:])
mesh = jnp.zeros([
mesh_shape[0] // mesh_size['nx'] +
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
mesh = paint.cic_paint(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
# Halo exchange along x
left = lax.pshuffle(mesh[-2*halo_size:],
left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:2*halo_size],
right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = mesh.at[:2*halo_size].add(left)
mesh = mesh.at[-2*halo_size:].add(right)
mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y
left = lax.pshuffle(mesh[:, -2*halo_size:],
left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :2*halo_size],
right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = mesh.at[:, :2*halo_size].add(left)
mesh = mesh.at[:, -2*halo_size:].add(right)
mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'},),
out_axes=({0: 'x', 2: 'y'}),
in_axes=(
{
0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(mesh, pos):
@ -198,11 +255,13 @@ def cic_read(mesh, pos, halo_size=0):
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = paint.cic_read(mesh, pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
res = paint.cic_read(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1])
return fn(mesh, pos)
@ -211,12 +270,14 @@ def cic_read(mesh, pos, halo_size=0):
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x):
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
Changes the logical shape of the array, but no shuffling of the
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:])
return x.reshape([
mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit,
@ -224,8 +285,8 @@ def reshape_dense_to_split(x):
out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x):
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
Changes the logical shape of the array, but no shuffling of the
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -1,13 +1,14 @@
import jax
from jax.lax import linear_solve_p
import jax.numpy as jnp
from jax.experimental.maps import xmap
from functools import partial
import jax_cosmo as jc
from jaxpm.kernels import fftk
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import growth_factor, growth_rate, dGfa
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
positions, halo_size) for f in forces_k]
forces = [
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces)
@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units
kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
kvec = [
k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
) / (box_size[0] * box_size[1] * box_size[2])
pk = pk * (mesh_shape[0] * mesh_shape[1] *
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
jnp.sqrt(jc.background.Esqr(cosmo, a)))
p = dops.scalar_multiply(
dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p