remove deprecated stuff

This commit is contained in:
Francois Lanusse 2024-10-24 16:36:41 -04:00 committed by GitHub
parent ef7a7ef5c9
commit f57e32af7f
3 changed files with 0 additions and 392 deletions

View file

@ -1,292 +0,0 @@
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.experimental.pjit import PartitionSpec, pjit
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
axis_resources = {'x': 'nx', 'y': 'ny'}
mesh_size = {'nx': 2, 'ny': 2}
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def stack3d(a, b, c):
return jnp.stack([a, b, c], axis=-1)
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, [...]),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
return a * factor
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}, {
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def add(a, b):
return a + b
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
Args:
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
Returns:
3D FFT of the input, note that the dimensions of the output
are tranposed.
"""
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y]
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y', ...],
out_axes={
0: 'x',
2: 'y'
},
axis_resources=axis_resources)
def fn(key):
""" Generate a distributed random normal distributions
Args:
key: array of random keys with same layout as computational mesh
shape: logical shape of array to sample
"""
return jax.random.normal(
key,
shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
kx, ky, kz = kvec
kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
@partial(xmap,
in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
out_axes=(['x', 'y', 'z']),
axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1. / kk)
return (kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
@partial(xmap,
in_axes=([...]),
out_axes={
0: 'x',
2: 'y'
},
axis_sizes={
'x': mesh_size['nx'],
'y': mesh_size['ny']
},
axis_resources=axis_resources)
def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({
0: 'x',
2: 'y'
}),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(pos):
mesh = jnp.zeros([
mesh_shape[0] // mesh_size['nx'] +
2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
] + mesh_shape[2:])
# Paint particles
mesh = paint.cic_paint(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
# Halo exchange along x
left = lax.pshuffle(mesh[-2 * halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:2 * halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = mesh.at[:2 * halo_size].add(left)
mesh = mesh.at[-2 * halo_size:].add(right)
# Halo exchange along y
left = lax.pshuffle(mesh[:, -2 * halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :2 * halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = mesh.at[:, :2 * halo_size].add(left)
mesh = mesh.at[:, -2 * halo_size:].add(right)
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=(
{
0: 'x',
2: 'y'
},
{
0: 'x',
2: 'y'
},
),
out_axes=({
0: 'x',
2: 'y'
}),
axis_resources=axis_resources)
def fn(mesh, pos):
# Halo exchange to grab neighboring borders
# Exchange along x
left = lax.pshuffle(mesh[-halo_size:],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
right = lax.pshuffle(mesh[:halo_size],
perm=range(mesh_size['nx'])[::-1],
axis_name='x')
mesh = jnp.concatenate([left, mesh, right], axis=0)
# Exchange along y
left = lax.pshuffle(mesh[:, -halo_size:],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
right = lax.pshuffle(mesh[:, :halo_size],
perm=range(mesh_size['ny'])[::-1],
axis_name='y')
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = paint.cic_read(
mesh,
pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res.reshape(pos.shape[:-1])
return fn(mesh, pos)
@partial(pjit,
in_axis_resources=PartitionSpec('nx', 'ny'),
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x):
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([
mesh_size['nx'], shape[0] //
mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
] + shape[2:])
@partial(pjit,
in_axis_resources=PartitionSpec('nx', None, 'ny', None),
out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x):
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
Changes the logical shape of the array, but no shuffling of the
data should be necessary
"""
shape = list(x.shape)
return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])

View file

@ -1,100 +0,0 @@
from functools import partial
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.maps import xmap
from jax.lax import linear_solve_p
import jaxpm.experimental.distributed_ops as dops
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import fftk
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
"""
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
mesh_shape = delta_k.shape
kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)]
if delta_k is None:
delta = dops.cic_paint(positions, mesh_shape, halo_size)
delta_k = dops.fft3d(dops.reshape_split_to_dense(delta))
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
# Recovers forces at particle positions
forces = [
dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
halo_size) for f in forces_k
]
return dops.stack3d(*forces)
def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
"""
Generate initial conditions.
Seed should have the dimension of the computational mesh
"""
# Sample normal field
field = dops.normal(seed, shape=mesh_shape)
# Go to Fourier space
field = dops.fft3d(dops.reshape_split_to_dense(field))
# Rescaling k to physical units
kvec = [
k.squeeze() / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(mesh_shape, symmetric=False))
]
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] *
mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
if return_Fourier:
return field
else:
return dops.reshape_dense_to_split(dops.ifft3d(field))
def lpt(cosmo, initial_conditions, positions, a):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta_k=initial_conditions)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(
dx,
a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = dops.scalar_multiply(
vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
# Computes the update of velocity (kick)
dvel = dops.scalar_multiply(
forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
return dpos, dvel
return nbody_ode