JaxPM/jaxpm/pm.py
Wassim KABALAN df8602b318 jaxdecomp proto (#21)
* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operations in Fourier space now take inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge hugos LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
2024-12-20 05:44:02 -05:00

268 lines
8.8 KiB
Python

import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.distributed import fft3d, ifft3d, normal_field
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
def pm_forces(positions,
              mesh_shape=None,
              delta=None,
              r_split=0,
              paint_absolute_pos=True,
              halo_size=0,
              sharding=None):
    """
    Evaluate particle-mesh (PM) gravitational forces at the particles.

    Args:
      positions: particle coordinates handed to ``cic_paint``/``cic_read``
        when ``paint_absolute_pos`` is True, or the array handed to
        ``cic_paint_dx``/``cic_read_dx`` otherwise.
      mesh_shape: shape of the mesh; defaults to ``delta.shape`` when omitted
        (in which case ``delta`` is required).
      delta: optional density field. If omitted, the field is painted from
        ``positions``. A real-valued ``delta`` is transformed with ``fft3d``;
        any other ``delta`` is used unchanged as the Fourier-space field.
      r_split: forwarded to ``longrange_kernel``.
      paint_absolute_pos: choose the absolute (True) or the ``*_dx`` (False)
        painting/reading routines.
      halo_size: forwarded to the painting and reading routines.
      sharding: forwarded to the painting and reading routines, and used as
        the ``device`` of the empty mesh that is painted onto.

    Returns:
      The three force components stacked along a new trailing axis.
    """
    if mesh_shape is None:
        assert (delta is not None),\
            "If mesh_shape is not provided, delta should be provided"
        mesh_shape = delta.shape

    # Select the paint/read pair once; both close over the halo and sharding.
    if paint_absolute_pos:

        def paint(pos):
            empty_mesh = jnp.zeros(shape=mesh_shape, device=sharding)
            return cic_paint(empty_mesh,
                             pos,
                             halo_size=halo_size,
                             sharding=sharding)

        def read(grid_mesh, pos):
            return cic_read(grid_mesh,
                            pos,
                            halo_size=halo_size,
                            sharding=sharding)
    else:

        def paint(disp):
            return cic_paint_dx(disp, halo_size=halo_size, sharding=sharding)

        def read(grid_mesh, disp):
            return cic_read_dx(grid_mesh,
                               disp,
                               halo_size=halo_size,
                               sharding=sharding)

    # Obtain the Fourier-space density from whichever input was supplied.
    if delta is None:
        delta_k = fft3d(paint(positions))
    elif jnp.isrealobj(delta):
        delta_k = fft3d(delta)
    else:
        delta_k = delta

    kvec = fftk(delta_k)

    # Gravitational potential in Fourier space
    pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
        kvec, r_split=r_split)

    # One force component per axis: differentiate, transform back, interpolate.
    components = []
    for axis in range(3):
        force_mesh = ifft3d(-gradient_kernel(kvec, axis) * pot_k)
        components.append(read(force_mesh, positions))

    return jnp.stack(components, axis=-1)
def lpt(cosmo,
        initial_conditions,
        particles=None,
        a=0.1,
        halo_size=0,
        sharding=None,
        order=1):
    """
    First and (optionally) second order LPT displacement and momentum,
    e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)

    Args:
      cosmo: cosmology object passed to the growth and background functions.
      initial_conditions: real-space initial field, transformed with ``fft3d``.
      particles: particle positions. When omitted, a zero array of shape
        ``(*initial_conditions.shape, 3)`` is used and forces are evaluated
        with ``paint_absolute_pos=False``.
      a: scale factor; promoted with ``jnp.atleast_1d``.
      halo_size: forwarded to ``pm_forces``.
      sharding: forwarded to ``pm_forces``.
      order: the second order terms are added only when this equals 2.

    Returns:
      Tuple ``(dx, p, f)``.
    """
    absolute = particles is not None
    if not absolute:
        particles = jnp.zeros_like(initial_conditions,
                                   shape=(*initial_conditions.shape, 3))

    a = jnp.atleast_1d(a)
    E = jnp.sqrt(jc.background.Esqr(cosmo, a))
    delta_k = fft3d(initial_conditions)

    # Options shared by both force evaluations below.
    force_options = dict(paint_absolute_pos=absolute,
                         halo_size=halo_size,
                         sharding=sharding)

    # First order contribution
    force1 = pm_forces(particles, delta=delta_k, **force_options)
    dx = growth_factor(cosmo, a) * force1
    p = a**2 * growth_rate(cosmo, a) * E * dx
    f = a**2 * E * dGfa(cosmo, a) * force1

    if order != 2:
        return dx, p, f

    # Second order source term, built from the second derivatives of the
    # first order potential.
    kvec = fftk(delta_k)
    pot_k = delta_k * invlaplace_kernel(kvec)
    grad_k = [gradient_kernel(kvec, axis) for axis in range(3)]

    delta2 = 0
    diag_sum = 0
    for i in range(3):
        # Products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
        shear_ii = ifft3d(grad_k[i]**2 * pot_k)
        delta2 += shear_ii * diag_sum
        diag_sum += shear_ii

        # Subtract the squared strictly-upper-triangular terms
        for j in range(i + 1, 3):
            shear_ij = ifft3d(grad_k[i] * grad_k[j] * pot_k)
            delta2 -= shear_ij**2

    force2 = pm_forces(particles, delta=fft3d(delta2), **force_options)

    # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
    dx2 = 3 / 7 * growth_factor_second(cosmo, a) * force2
    p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
    f2 = a**2 * E * dGf2a(cosmo, a) * force2

    return dx + dx2, p + p2, f + f2
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
    """
    Generate initial conditions.

    A random field from ``normal_field`` is transformed with ``fft3d``,
    multiplied by the square root of ``pk`` (rescaled by the ratio of the
    number of mesh cells to the box volume), and transformed back with
    ``ifft3d``.

    Args:
      mesh_shape: three mesh dimensions.
      box_size: three box lengths.
      pk: callable evaluated on the wavenumber magnitude mesh.
      seed: forwarded to ``normal_field``.
      sharding: forwarded to ``normal_field``.

    Returns:
      The real-space field returned by ``ifft3d``.
    """
    # Random field drawn according to the requested sharding
    noise_k = fft3d(normal_field(mesh_shape, seed=seed, sharding=sharding))
    kvec = fftk(noise_k)

    # Wavenumber magnitude, each axis rescaled by mesh_shape / box_size
    k_squared = 0
    for axis, kk in enumerate(kvec):
        k_squared = k_squared + (kk / box_size[axis] * mesh_shape[axis])**2
    kmesh = k_squared**0.5

    n_cells = mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
    volume = box_size[0] * box_size[1] * box_size[2]
    pkmesh = pk(kmesh) * n_cells / volume

    return ifft3d(noise_k * (pkmesh)**0.5)
def make_ode_fn(mesh_shape,
                paint_absolute_pos=True,
                halo_size=0,
                sharding=None):
    """
    Build the right-hand side of the N-body equations of motion.

    Args:
      mesh_shape: forwarded to ``pm_forces``.
      paint_absolute_pos: forwarded to ``pm_forces``.
      halo_size: forwarded to ``pm_forces``.
      sharding: forwarded to ``pm_forces``.

    Returns:
      A function ``nbody_ode(state, a, cosmo)`` returning ``(dpos, dvel)``.
    """
    force_options = dict(mesh_shape=mesh_shape,
                         paint_absolute_pos=paint_absolute_pos,
                         halo_size=halo_size,
                         sharding=sharding)

    def nbody_ode(state, a, cosmo):
        """
        state is a tuple (position, velocities)
        """
        pos, vel = state

        forces = pm_forces(pos, **force_options) * 1.5 * cosmo.Omega_m

        # Both updates share the same E(a) factor
        E = jnp.sqrt(jc.background.Esqr(cosmo, a))

        # Drift: update of the positions
        dpos = 1. / (a**3 * E) * vel
        # Kick: update of the velocities
        dvel = 1. / (a**2 * E) * forces

        return dpos, dvel

    return nbody_ode
def make_diffrax_ode(cosmo,
                     mesh_shape,
                     paint_absolute_pos=True,
                     halo_size=0,
                     sharding=None):
    """
    Build the N-body right-hand side with an ``(a, state, args)`` signature.

    Args:
      cosmo: cosmology object captured by the returned function.
      mesh_shape: forwarded to ``pm_forces``.
      paint_absolute_pos: forwarded to ``pm_forces``.
      halo_size: forwarded to ``pm_forces``.
      sharding: forwarded to ``pm_forces``.

    Returns:
      A function ``nbody_ode(a, state, args)`` returning the position and
      velocity updates stacked along a new leading axis. ``args`` is unused.
    """
    force_options = dict(mesh_shape=mesh_shape,
                         paint_absolute_pos=paint_absolute_pos,
                         halo_size=halo_size,
                         sharding=sharding)

    def nbody_ode(a, state, args):
        """
        state is a tuple (position, velocities)
        """
        pos, vel = state

        forces = pm_forces(pos, **force_options) * 1.5 * cosmo.Omega_m

        # Both updates share the same E(a) factor
        E = jnp.sqrt(jc.background.Esqr(cosmo, a))

        # Drift: update of the positions
        dpos = 1. / (a**3 * E) * vel
        # Kick: update of the velocities
        dvel = 1. / (a**2 * E) * forces

        return jnp.stack([dpos, dvel])

    return nbody_ode
def pgd_correction(pos, mesh_shape, params):
    """
    improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
    based on https://arxiv.org/abs/1804.00671

    args:
      pos: particle positions [npart, 3]
      mesh_shape: shape of the mesh the particles are painted onto
      params: [alpha, kl, ks] pgd parameters

    returns:
      displacement correction: the PGD forces read at ``pos``, scaled by
      ``alpha``, with the three components stacked along the last axis
    """
    delta = cic_paint(jnp.zeros(mesh_shape), pos)
    delta_k = fft3d(delta)
    kvec = fftk(delta_k)

    alpha, kl, ks = params
    PGD_range = PGD_kernel(kvec, kl, ks)

    # Potential filtered by the PGD kernel
    pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range

    # The gradient is taken in Fourier space, so the result must go through
    # the inverse transform (as in pm_forces) before being read at the
    # particle positions; the forward fft3d used previously was incorrect.
    forces_pgd = jnp.stack([
        cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
        for i in range(3)
    ],
                           axis=-1)

    dpos_pgd = forces_pgd * alpha

    return dpos_pgd
def make_neural_ode_fn(model, mesh_shape):
    """
    Build an N-body right-hand side whose potential is corrected by a model.

    Args:
      model: object exposing ``apply(params, kk, a)``; its output multiplies
        the Fourier-space potential as ``1 + model.apply(...)``.
      mesh_shape: shape of the mesh the particles are painted onto.

    Returns:
      A function ``neural_nbody_ode(state, a, cosmo, params)`` returning
      ``(dpos, dvel)``.
    """

    # NOTE: ``cosmo`` is intentionally left unannotated. The previous
    # ``cosmo: Cosmology`` annotation used a name that is not imported in this
    # module, which raised NameError as soon as this factory was called.
    def neural_nbody_ode(state, a, cosmo, params):
        """
        state is a tuple (position, velocities)
        """
        pos, vel = state

        delta = cic_paint(jnp.zeros(mesh_shape), pos)

        delta_k = fft3d(delta)
        kvec = fftk(delta_k)

        # Computes gravitational potential
        pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
            kvec, r_split=0)

        # Apply a correction filter
        kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
        pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))

        # Computes gravitational forces. The gradient is applied in Fourier
        # space, so the inverse transform (as in pm_forces) is needed before
        # reading at the particle positions; the forward fft3d used
        # previously was incorrect.
        forces = jnp.stack([
            cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k), pos)
            for i in range(3)
        ],
                           axis=-1)

        forces = forces * 1.5 * cosmo.Omega_m

        # Both updates share the same E(a) factor
        E = jnp.sqrt(jc.background.Esqr(cosmo, a))

        # Computes the update of position (drift)
        dpos = 1. / (a**3 * E) * vel

        # Computes the update of velocity (kick)
        dvel = 1. / (a**2 * E) * forces

        return dpos, dvel

    return neural_nbody_ode