From 1948eae9edc7e852fd4d4e154b98ceebefe4b7de Mon Sep 17 00:00:00 2001 From: EiffL Date: Sat, 22 Oct 2022 11:30:25 -0500 Subject: [PATCH] Adding begnning of implem --- jaxpm/kernels.py | 100 ++++++++++++++--------------------------------- jaxpm/ops.py | 16 +++++++- jaxpm/pm.py | 39 +++++++++++------- 3 files changed, 68 insertions(+), 87 deletions(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 97d34dd..61b9e58 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,88 +1,46 @@ import numpy as np import jax.numpy as jnp -def fftk(shape, symmetric=True, finite=False, dtype=np.float32): - """ Return k_vector given a shape (nc, nc, nc) and box_size +def fftk(shape, symmetric=True, dtype=np.float32, comms=None): + """ Return k_vector given a shape (nc, nc, nc) """ k = [] + + if comms is not None: + nx = comms[0].Get_size() + ix = comms[0].Get_rank() + ny = comms[1].Get_size() + iy = comms[1].Get_rank() + shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:]) + for d in range(len(shape)): kd = np.fft.fftfreq(shape[d]) kd *= 2 * np.pi - kdshape = np.ones(len(shape), dtype='int') + if symmetric and d == len(shape) - 1: kd = kd[:shape[d] // 2 + 1] - kdshape[d] = len(kd) - kd = kd.reshape(kdshape) + + if (comms is not None) and d==0: + kd = kd.reshape([nx, -1])[ix] + + if (comms is not None) and d==1: + kd = kd.reshape([ny, -1])[iy] k.append(kd.astype(dtype)) - del kd, kdshape return k -def gradient_kernel(kvec, direction, order=1): - """ - Computes the gradient kernel in the requested direction - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - direction: int - Index of the direction in which to take the gradient - Returns: - -------- - wts: array - Complex kernel - """ - if order == 0: - wts = 1j * kvec[direction] - wts = jnp.squeeze(wts) - wts[len(wts) // 2] = 0 - wts = wts.reshape(kvec[direction].shape) - return wts - else: - w = kvec[direction] - a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) - wts = a * 1j - return wts - -def laplace_kernel(kvec): - """ - Compute the Laplace kernel from a given K vector - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - Returns: - -------- - wts: array - Complex kernel - """ - kk = sum(ki**2 for ki in kvec) - mask = (kk == 0).nonzero() - kk[mask] = 1 - wts = 1. / kk - imask = (~(kk == 0)).astype(int) - wts *= imask - return wts - -def longrange_kernel(kvec, r_split): - """ - Computes a long range kernel - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - r_split: float - TODO: @modichirag add documentation - Returns: - -------- - wts: array - kernel - """ - if r_split != 0: - kk = sum(ki**2 for ki in kvec) - return np.exp(-kk * r_split**2) - else: - return 1. +@partial(jax.pmap, + in_axes=[['x','y','z'], + ['x'],['y'],['z']], + out_axes=['x','y','z',...]) +def apply_gradient_laplace(kfield, kvec): + kx, ky, kz = kvec + kk = (kx**2 + ky**2 + kz**2) + kernel = jnp.where(kk == 0, 1., 1./kk) + return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), + kfield * kernel * 1j * 1 / 6.0 * + (8 * jnp.sin(kz) - jnp.sin(2 * kz)), + kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))],axis=-1) def cic_compensation(kvec): """ diff --git a/jaxpm/ops.py b/jaxpm/ops.py index b4c1801..68bf355 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -100,12 +100,24 @@ def halo_reduce(arr, halo_size, token=None, comms=None): rank_y = comms[1].Get_rank() margin = arr[:, -2*halo_size:] margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1, - comm=comms[0], token=token) + comm=comms[1], token=token) arr = arr.at[:, :2*halo_size].add(margin) margin = arr[:, :2*halo_size] margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1, - comm=comms[0], token=token) + comm=comms[1], token=token) arr = arr.at[:, -2*halo_size:].add(margin) return arr, token + +def zeros(shape, comms=None): + """ Initialize an array of given global shape + partitionned if need be accross dimensions. + """ + if comms is None: + return jnp.zeros(shape) + + nx = comms[0].Get_size() + ny = comms[1].Get_size() + + return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:])) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index d9870f7..3fcc760 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -3,40 +3,51 @@ import jax.numpy as jnp import jax_cosmo as jc -from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel +from jaxpm.ops import fft3d, ifft3d, zeros +from jaxpm.kernels import fftk, apply_gradient_laplace from jaxpm.painting import cic_paint, cic_read from jaxpm.growth import growth_factor, growth_rate, dGfa -def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): +def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None): """ Computes gravitational forces on particles using a PM scheme """ if mesh_shape is None: - mesh_shape = delta.shape - kvec = fftk(mesh_shape) + mesh_shape = delta_k.shape - if delta is None: - delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions)) - else: - delta_k = jnp.fft.rfftn(delta) + kvec = fftk(mesh_shape, comms=comms) + + if delta_k is None: + delta, token = cic_paint(zeros(mesh_shape,comms=comms), + positions, + halo_size=halo_size, token=token, comms=comms) + delta_k, token = fft3d(delta, token=token, comms=comms) # Computes gravitational potential - pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) + forces_k = apply_gradient_laplace(kfield, kvec) + # Computes gravitational forces - return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions) - for i in range(3)],axis=-1) + fx, token = ifft3d(forces_k[...,0], token=token, comms=comms) + fx, token = cic_read(fx, positions, halo_size=halo_size, comms=comms) + fy, token = ifft3d(forces_k[...,1], token=token, comms=comms) + fy, token = cic_read(fy, positions, halo_size=halo_size, comms=comms) -def lpt(cosmo, initial_conditions, positions, a): + fz, token = ifft3d(forces_k[...,2], token=token, comms=comms) + fz, token = cic_read(fz, positions, halo_size=halo_size, comms=comms) + + return jnp.stack([fx,fy,fz],axis=-1), token + +def lpt(cosmo, initial_conditions, positions, a, token=token, comms=comms): """ Computes first order LPT displacement """ - initial_force = pm_forces(positions, delta=initial_conditions) + initial_force = pm_forces(positions, delta=initial_conditions, token=token, comms=comms) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force - return dx, p, f + return dx, p, f, comms def linear_field(mesh_shape, box_size, pk, seed): """