diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..d476f32 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/google/yapf + rev: v0.40.2 + hooks: + - id: yapf + args: ['--parallel', '--in-place'] +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) \ No newline at end of file diff --git a/design.md b/design.md index 329270a..a0727a1 100644 --- a/design.md +++ b/design.md @@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern ## Objective -Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. +Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. ## Related Work This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models. - [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow -- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD +- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD - Borg diff --git a/dev/job_pfft.sh b/dev/job_pfft.sh new file mode 100644 index 0000000..a0b73c0 --- /dev/null +++ b/dev/job_pfft.sh @@ -0,0 +1,14 @@ +#!/bin/bash +#SBATCH -A m1727 +#SBATCH -C gpu +#SBATCH -q debug +#SBATCH -t 0:05:00 +#SBATCH -N 2 +#SBATCH --ntasks-per-node=4 +#SBATCH -c 32 +#SBATCH --gpus-per-task=1 +#SBATCH --gpu-bind=none + +module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit +export SLURM_CPU_BIND="cores" +srun python test_pfft.py diff --git a/dev/test_pfft.py b/dev/test_pfft.py new file mode 100644 index 0000000..5a956d8 --- /dev/null +++ b/dev/test_pfft.py @@ -0,0 +1,96 @@ +# Can be executed with: +# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py +from functools import partial + +import jax +import jax.lax as lax +import jax.numpy as jnp +import numpy as np +from jax.experimental.maps import Mesh, xmap +from jax.experimental.pjit import PartitionSpec, pjit + +jax.distributed.initialize() + +cube_size = 2048 + + +@partial(xmap, + in_axes=[...], + out_axes=['x', 'y', ...], + axis_sizes={ + 'x': cube_size, + 'y': cube_size + }, + axis_resources={ + 'x': 'nx', + 'y': 'ny', + 'key_x': 'nx', + 'key_y': 'ny' + }) +def pnormal(key): + return jax.random.normal(key, shape=[cube_size]) + + +@partial(xmap, + in_axes={ + 0: 'x', + 1: 'y' + }, + out_axes=['x', 'y', ...], + axis_resources={ + 'x': 'nx', + 'y': 'ny' + }) +@jax.jit +def pfft3d(mesh): + # [x, y, z] + mesh = jnp.fft.fft(mesh) # Transform on z + mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x] + mesh = jnp.fft.fft(mesh) # Transform on x + mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y] + mesh = jnp.fft.fft(mesh) # Transform on y + # [z, x, y] + return mesh + + +@partial(xmap, + in_axes={ + 0: 'x', + 1: 'y' + }, + out_axes=['x', 'y', ...], + axis_resources={ + 'x': 'nx', + 'y': 'ny' + }) +@jax.jit +def pifft3d(mesh): + # [z, x, y] + mesh = jnp.fft.ifft(mesh) # Transform on y + mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x] + mesh = jnp.fft.ifft(mesh) # Transform on x + mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z] + mesh = jnp.fft.ifft(mesh) # Transform on z + # [x, y, z] + return mesh + + +key = jax.random.PRNGKey(42) +# keys = jax.random.split(key, 4).reshape((2,2,2)) + +# We reshape all our devices to the mesh shape we want +devices = np.array(jax.devices()).reshape((2, 4)) + +with Mesh(devices, ('nx', 'ny')): + mesh = pnormal(key) + kmesh = pfft3d(mesh) + kmesh.block_until_ready() + +# jax.profiler.start_trace("tensorboard") +# with Mesh(devices, ('nx', 'ny')): +# mesh = pnormal(key) +# kmesh = pfft3d(mesh) +# kmesh.block_until_ready() +# jax.profiler.stop_trace() + +print('Done') diff --git a/dev/test_script.py b/dev/test_script.py index a9566c2..4f3ca06 100644 --- a/dev/test_script.py +++ b/dev/test_script.py @@ -1,48 +1,53 @@ # Start this script with: # mpirun -np 4 python test_script.py import os + os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' -import matplotlib.pylab as plt -import jax -import numpy as np -import jax.numpy as jnp +import jax import jax.lax as lax +import jax.numpy as jnp +import matplotlib.pylab as plt +import numpy as np +import tensorflow_probability as tfp from jax.experimental.maps import mesh, xmap from jax.experimental.pjit import PartitionSpec, pjit -import tensorflow_probability as tfp; tfp = tfp.substrates.jax + +tfp = tfp.substrates.jax tfd = tfp.distributions + def cic_paint(mesh, positions): - """ Paints positions onto mesh + """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], - [0., 0, 1], [1., 1, 0], [1., 0, 1], - [0., 1, 1], [1., 1, 1]]]) + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], + [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, + 2)) + mesh = lax.scatter_add( + mesh, + neighboor_coords.reshape([-1, 8, 3]).astype('int32'), + kernel.reshape([-1, 8]), dnums) + return mesh - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1, 2), - scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, - neighboor_coords.reshape([-1,8,3]).astype('int32'), - kernel.reshape([-1,8]), - dnums) - return mesh # And let's draw some points from some 3D distribution -dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.) +dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.], + scale_identity_multiplier=3.) pos = dist.sample(1e4, seed=jax.random.PRNGKey(0)) f = pjit(lambda x: cic_paint(x, pos), - in_axis_resources=PartitionSpec('x', 'y', 'z'), + in_axis_resources=PartitionSpec('x', 'y', 'z'), out_axis_resources=None) devices = np.array(jax.devices()).reshape((2, 2, 1)) @@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1)) m = jnp.zeros([32, 32, 32]) with mesh(devices, ('x', 'y', 'z')): - # Shard the mesh, I'm not sure this is absolutely necessary - m = pjit(lambda x: x, - in_axis_resources=None, - out_axis_resources=PartitionSpec('x', 'y', 'z'))(m) + # Shard the mesh, I'm not sure this is absolutely necessary + m = pjit(lambda x: x, + in_axis_resources=None, + out_axis_resources=PartitionSpec('x', 'y', 'z'))(m) - # Apply the sharded CiC function - res = f(m) + # Apply the sharded CiC function + res = f(m) plt.imshow(res.sum(axis=2)) -plt.show() \ No newline at end of file +plt.show() diff --git a/jaxpm/experimental/__init__.py b/jaxpm/experimental/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jaxpm/experimental/distributed_ops.py b/jaxpm/experimental/distributed_ops.py new file mode 100644 index 0000000..a06b03c --- /dev/null +++ b/jaxpm/experimental/distributed_ops.py @@ -0,0 +1,292 @@ +from functools import partial + +import jax +import jax.lax as lax +import jax.numpy as jnp +import jax_cosmo as jc +from jax.experimental.maps import xmap +from jax.experimental.pjit import PartitionSpec, pjit + +import jaxpm.painting as paint + +# TODO: add a way to configure axis resources from command line +axis_resources = {'x': 'nx', 'y': 'ny'} +mesh_size = {'nx': 2, 'ny': 2} + + +@partial(xmap, + in_axes=({ + 0: 'x', + 2: 'y' + }, { + 0: 'x', + 2: 'y' + }, { + 0: 'x', + 2: 'y' + }), + out_axes=({ + 0: 'x', + 2: 'y' + }), + axis_resources=axis_resources) +def stack3d(a, b, c): + return jnp.stack([a, b, c], axis=-1) + + +@partial(xmap, + in_axes=({ + 0: 'x', + 2: 'y' + }, [...]), + out_axes=({ + 0: 'x', + 2: 'y' + }), + axis_resources=axis_resources) +def scalar_multiply(a, factor): + return a * factor + + +@partial(xmap, + in_axes=({ + 0: 'x', + 2: 'y' + }, { + 0: 'x', + 2: 'y' + }), + out_axes=({ + 0: 'x', + 2: 'y' + }), + axis_resources=axis_resources) +def add(a, b): + return a + b + + +@partial(xmap, + in_axes=['x', 'y', ...], + out_axes=['x', 'y', ...], + axis_resources=axis_resources) +def fft3d(mesh): + """ Performs a 3D complex Fourier transform + + Args: + mesh: a real 3D tensor of shape [Nx, Ny, Nz] + + Returns: + 3D FFT of the input, note that the dimensions of the output + are tranposed. + """ + mesh = jnp.fft.fft(mesh) + mesh = lax.all_to_all(mesh, 'x', 0, 0) + mesh = jnp.fft.fft(mesh) + mesh = lax.all_to_all(mesh, 'y', 0, 0) + return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y] + + +@partial(xmap, + in_axes=['x', 'y', ...], + out_axes=['x', 'y', ...], + axis_resources=axis_resources) +def ifft3d(mesh): + mesh = jnp.fft.ifft(mesh) + mesh = lax.all_to_all(mesh, 'y', 0, 0) + mesh = jnp.fft.ifft(mesh) + mesh = lax.all_to_all(mesh, 'x', 0, 0) + return jnp.fft.ifft(mesh).real + + +def normal(key, shape=[]): + + @partial(xmap, + in_axes=['x', 'y', ...], + out_axes={ + 0: 'x', + 2: 'y' + }, + axis_resources=axis_resources) + def fn(key): + """ Generate a distributed random normal distributions + Args: + key: array of random keys with same layout as computational mesh + shape: logical shape of array to sample + """ + return jax.random.normal( + key, + shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] + + shape[2:]) + + return fn(key) + + +@partial(xmap, + in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]), + out_axes=['x', 'y', ...], + axis_resources=axis_resources) +@jax.jit +def scale_by_power_spectrum(kfield, kvec, k, pk): + kx, ky, kz = kvec + kk = jnp.sqrt(kx**2 + ky**2 + kz**2) + return kfield * jc.scipy.interpolate.interp(kk, k, pk) + + +@partial(xmap, + in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]), + out_axes=(['x', 'y', 'z']), + axis_resources=axis_resources) +def gradient_laplace_kernel(kfield, kvec): + kx, ky, kz = kvec + kk = (kx**2 + ky**2 + kz**2) + kernel = jnp.where(kk == 0, 1., 1. / kk) + return (kfield * kernel * 1j * 1 / 6.0 * + (8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 / + 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j * + 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))) + + +@partial(xmap, + in_axes=([...]), + out_axes={ + 0: 'x', + 2: 'y' + }, + axis_sizes={ + 'x': mesh_size['nx'], + 'y': mesh_size['ny'] + }, + axis_resources=axis_resources) +def meshgrid(x, y, z): + """ Generates a mesh grid of appropriate size for the + computational mesh we have. + """ + return jnp.stack(jnp.meshgrid(x, y, z), axis=-1) + + +def cic_paint(pos, mesh_shape, halo_size=0): + + @partial(xmap, + in_axes=({ + 0: 'x', + 2: 'y' + }), + out_axes=({ + 0: 'x', + 2: 'y' + }), + axis_resources=axis_resources) + def fn(pos): + + mesh = jnp.zeros([ + mesh_shape[0] // mesh_size['nx'] + + 2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size + ] + mesh_shape[2:]) + + # Paint particles + mesh = paint.cic_paint( + mesh, + pos.reshape(-1, 3) + + jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) + + # Perform halo exchange + # Halo exchange along x + left = lax.pshuffle(mesh[-2 * halo_size:], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + right = lax.pshuffle(mesh[:2 * halo_size], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + mesh = mesh.at[:2 * halo_size].add(left) + mesh = mesh.at[-2 * halo_size:].add(right) + + # Halo exchange along y + left = lax.pshuffle(mesh[:, -2 * halo_size:], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + right = lax.pshuffle(mesh[:, :2 * halo_size], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + mesh = mesh.at[:, :2 * halo_size].add(left) + mesh = mesh.at[:, -2 * halo_size:].add(right) + + # removing halo and returning mesh + return mesh[halo_size:-halo_size, halo_size:-halo_size] + + return fn(pos) + + +def cic_read(mesh, pos, halo_size=0): + + @partial(xmap, + in_axes=( + { + 0: 'x', + 2: 'y' + }, + { + 0: 'x', + 2: 'y' + }, + ), + out_axes=({ + 0: 'x', + 2: 'y' + }), + axis_resources=axis_resources) + def fn(mesh, pos): + + # Halo exchange to grab neighboring borders + # Exchange along x + left = lax.pshuffle(mesh[-halo_size:], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + right = lax.pshuffle(mesh[:halo_size], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + mesh = jnp.concatenate([left, mesh, right], axis=0) + # Exchange along y + left = lax.pshuffle(mesh[:, -halo_size:], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + right = lax.pshuffle(mesh[:, :halo_size], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + mesh = jnp.concatenate([left, mesh, right], axis=1) + + # Reading field at particles positions + res = paint.cic_read( + mesh, + pos.reshape(-1, 3) + + jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) + + return res.reshape(pos.shape[:-1]) + + return fn(mesh, pos) + + +@partial(pjit, + in_axis_resources=PartitionSpec('nx', 'ny'), + out_axis_resources=PartitionSpec('nx', None, 'ny', None)) +def reshape_dense_to_split(x): + """ Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z] + Changes the logical shape of the array, but no shuffling of the + data should be necessary + """ + shape = list(x.shape) + return x.reshape([ + mesh_size['nx'], shape[0] // + mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny'] + ] + shape[2:]) + + +@partial(pjit, + in_axis_resources=PartitionSpec('nx', None, 'ny', None), + out_axis_resources=PartitionSpec('nx', 'ny')) +def reshape_split_to_dense(x): + """ Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z] + Changes the logical shape of the array, but no shuffling of the + data should be necessary + """ + shape = list(x.shape) + return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:]) diff --git a/jaxpm/experimental/distributed_pm.py b/jaxpm/experimental/distributed_pm.py new file mode 100644 index 0000000..b633cf7 --- /dev/null +++ b/jaxpm/experimental/distributed_pm.py @@ -0,0 +1,100 @@ +from functools import partial + +import jax +import jax.numpy as jnp +import jax_cosmo as jc +from jax.experimental.maps import xmap +from jax.lax import linear_solve_p + +import jaxpm.experimental.distributed_ops as dops +from jaxpm.growth import dGfa, growth_factor, growth_rate +from jaxpm.kernels import fftk + + +def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16): + """ + Computes gravitational forces on particles using a PM scheme + """ + if mesh_shape is None: + mesh_shape = delta_k.shape + kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)] + + if delta_k is None: + delta = dops.cic_paint(positions, mesh_shape, halo_size) + delta_k = dops.fft3d(dops.reshape_split_to_dense(delta)) + + forces_k = dops.gradient_laplace_kernel(delta_k, kvec) + + # Recovers forces at particle positions + forces = [ + dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions, + halo_size) for f in forces_k + ] + + return dops.stack3d(*forces) + + +def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True): + """ + Generate initial conditions. + Seed should have the dimension of the computational mesh + """ + + # Sample normal field + field = dops.normal(seed, shape=mesh_shape) + + # Go to Fourier space + field = dops.fft3d(dops.reshape_split_to_dense(field)) + + # Rescaling k to physical units + kvec = [ + k.squeeze() / box_size[i] * mesh_shape[i] + for i, k in enumerate(fftk(mesh_shape, symmetric=False)) + ] + k = jnp.logspace(-4, 2, 256) + pk = jc.power.linear_matter_power(cosmo, k) + pk = pk * (mesh_shape[0] * mesh_shape[1] * + mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) + + field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk)) + + if return_Fourier: + return field + else: + return dops.reshape_dense_to_split(dops.ifft3d(field)) + + +def lpt(cosmo, initial_conditions, positions, a): + """ + Computes first order LPT displacement + """ + initial_force = pm_forces(positions, delta_k=initial_conditions) + a = jnp.atleast_1d(a) + dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a)) + p = dops.scalar_multiply( + dx, + a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a))) + return dx, p + + +def make_ode_fn(mesh_shape): + + def nbody_ode(state, a, cosmo): + """ + state is a tuple (position, velocities) + """ + pos, vel = state + + forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m + + # Computes the update of position (drift) + dpos = dops.scalar_multiply( + vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a)))) + + # Computes the update of velocity (kick) + dvel = dops.scalar_multiply( + forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)))) + + return dpos, dvel + + return nbody_ode diff --git a/jaxpm/growth.py b/jaxpm/growth.py index 0be4718..5b6908c 100644 --- a/jaxpm/growth.py +++ b/jaxpm/growth.py @@ -1,8 +1,8 @@ import jax.numpy as np - +from jax_cosmo.background import * from jax_cosmo.scipy.interpolate import interp from jax_cosmo.scipy.ode import odeint -from jax_cosmo.background import * + def E(cosmo, a): r"""Scale factor dependent factor E(a) in the Hubble @@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5): \frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)- \frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)} """ - return ( - 3 - * cosmo.wa - * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) - / np.power(np.log(a - epsilon), 2) - ) + return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) / + np.power(np.log(a - epsilon), 2)) def dEa(cosmo, a): @@ -89,15 +85,11 @@ def dEa(cosmo, a): where :math:`f(a)` is the Dark Energy evolution parameter computed by :py:meth:`.f_de`. """ - return ( - 0.5 - * ( - -3 * cosmo.Omega_m * np.power(a, -4) - - 2 * cosmo.Omega_k * np.power(a, -3) - + df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a)) - ) - / np.power(Esqr(cosmo, a), 0.5) - ) + return (0.5 * + (-3 * cosmo.Omega_m * np.power(a, -4) - + 2 * cosmo.Omega_k * np.power(a, -3) + + df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) / + np.power(Esqr(cosmo, a), 0.5)) def growth_factor(cosmo, a): @@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a): """ if cosmo._flags["gamma_growth"]: raise NotImplementedError( - "Gamma growth rate is not implemented for second order growth!" - ) + "Gamma growth rate is not implemented for second order growth!") return None else: return _growth_factor_second_ODE(cosmo, a) @@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a): """ if cosmo._flags["gamma_growth"]: raise NotImplementedError( - "Gamma growth factor is not implemented for second order growth!" - ) + "Gamma growth factor is not implemented for second order growth!") return None else: return _growth_rate_second_ODE(cosmo, a) @@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): atab = np.logspace(log10_amin, 0.0, steps) def D_derivs(y, x): - q = ( - 2.0 - - 0.5 - * ( - Omega_m_a(cosmo, x) - + (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x) - ) - ) / x + q = (2.0 - 0.5 * + (Omega_m_a(cosmo, x) + + (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x r = 1.5 * Omega_m_a(cosmo, x) / x / x g1, g2 = y[0] f1, f2 = y[1] dy1da = [f1, -q * f1 + r * g1] - dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2] + dy2da = [f2, -q * f2 + r * g2 - r * g1**2] return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]]) - y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]]) + y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2], + [1.0, -6.0 / 7 * atab[0]]]) y = odeint(D_derivs, y0, atab) # compute second order derivatives growth @@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a): see :cite:`2019:Euclid Preparation VII, eqn.32` """ - return Omega_m_a(cosmo, a) ** cosmo.gamma - + return Omega_m_a(cosmo, a)**cosmo.gamma def Gf(cosmo, a): @@ -503,7 +488,7 @@ def Gf(cosmo, a): """ f1 = growth_rate(cosmo, a) g1 = growth_factor(cosmo, a) - D1f = f1*g1/ a + D1f = f1 * g1 / a return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) @@ -532,7 +517,7 @@ def Gf2(cosmo, a): """ f2 = growth_rate_second(cosmo, a) g2 = growth_factor_second(cosmo, a) - D2f = f2*g2/ a + D2f = f2 * g2 / a return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) @@ -563,13 +548,12 @@ def dGfa(cosmo, a): """ f1 = growth_rate(cosmo, a) g1 = growth_factor(cosmo, a) - D1f = f1*g1/ a + D1f = f1 * g1 / a cache = cosmo._workspace['background.growth_factor'] f1p = cache['h'] / cache['a'] * cache['g'] f1p = interp(np.log(a), np.log(cache['a']), f1p) Ea = E(cosmo, a) - return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + - 3 * a**2 * Ea * D1f) + return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f) def dGf2a(cosmo, a): @@ -599,10 +583,9 @@ def dGf2a(cosmo, a): """ f2 = growth_rate_second(cosmo, a) g2 = growth_factor_second(cosmo, a) - D2f = f2*g2/ a + D2f = f2 * g2 / a cache = cosmo._workspace['background.growth_factor'] f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = interp(np.log(a), np.log(cache['a']), f2p) E = E(cosmo, a) - return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + - 3 * a**2 * E * D2f) \ No newline at end of file + return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 97d34dd..8447f8a 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,25 +1,27 @@ -import numpy as np import jax.numpy as jnp +import numpy as np + def fftk(shape, symmetric=True, finite=False, dtype=np.float32): - """ Return k_vector given a shape (nc, nc, nc) and box_size + """ Return k_vector given a shape (nc, nc, nc) and box_size """ - k = [] - for d in range(len(shape)): - kd = np.fft.fftfreq(shape[d]) - kd *= 2 * np.pi - kdshape = np.ones(len(shape), dtype='int') - if symmetric and d == len(shape) - 1: - kd = kd[:shape[d] // 2 + 1] - kdshape[d] = len(kd) - kd = kd.reshape(kdshape) + k = [] + for d in range(len(shape)): + kd = np.fft.fftfreq(shape[d]) + kd *= 2 * np.pi + kdshape = np.ones(len(shape), dtype='int') + if symmetric and d == len(shape) - 1: + kd = kd[:shape[d] // 2 + 1] + kdshape[d] = len(kd) + kd = kd.reshape(kdshape) + + k.append(kd.astype(dtype)) + del kd, kdshape + return k - k.append(kd.astype(dtype)) - del kd, kdshape - return k def gradient_kernel(kvec, direction, order=1): - """ + """ Computes the gradient kernel in the requested direction Parameters: ----------- @@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1): wts: array Complex kernel """ - if order == 0: - wts = 1j * kvec[direction] - wts = jnp.squeeze(wts) - wts[len(wts) // 2] = 0 - wts = wts.reshape(kvec[direction].shape) - return wts - else: - w = kvec[direction] - a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) - wts = a * 1j - return wts + if order == 0: + wts = 1j * kvec[direction] + wts = jnp.squeeze(wts) + wts[len(wts) // 2] = 0 + wts = wts.reshape(kvec[direction].shape) + return wts + else: + w = kvec[direction] + a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) + wts = a * 1j + return wts + def laplace_kernel(kvec): - """ + """ Compute the Laplace kernel from a given K vector Parameters: ----------- @@ -56,16 +59,17 @@ def laplace_kernel(kvec): wts: array Complex kernel """ - kk = sum(ki**2 for ki in kvec) - mask = (kk == 0).nonzero() - kk[mask] = 1 - wts = 1. / kk - imask = (~(kk == 0)).astype(int) - wts *= imask - return wts + kk = sum(ki**2 for ki in kvec) + mask = (kk == 0).nonzero() + kk[mask] = 1 + wts = 1. / kk + imask = (~(kk == 0)).astype(int) + wts *= imask + return wts + def longrange_kernel(kvec, r_split): - """ + """ Computes a long range kernel Parameters: ----------- @@ -78,29 +82,31 @@ def longrange_kernel(kvec, r_split): wts: array kernel """ - if r_split != 0: - kk = sum(ki**2 for ki in kvec) - return np.exp(-kk * r_split**2) - else: - return 1. + if r_split != 0: + kk = sum(ki**2 for ki in kvec) + return np.exp(-kk * r_split**2) + else: + return 1. + def cic_compensation(kvec): - """ + """ Computes cic compensation kernel. Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 Itself based on equation 18 (with p=2) of `Jing et al 2005 `_ Args: - kvec: array of k values in Fourier space + kvec: array of k values in Fourier space Returns: v: array of kernel """ - kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] - wts = (kwts[0] * kwts[1] * kwts[2])**(-2) - return wts + kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] + wts = (kwts[0] * kwts[1] * kwts[2])**(-2) + return wts + def PGD_kernel(kvec, kl, ks): - """ + """ Computes the PGD kernel Parameters: ----------- @@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks): v: array kernel """ - kk = sum(ki**2 for ki in kvec) - kl2 = kl**2 - ks4 = ks**4 - mask = (kk == 0).nonzero() - kk[mask] = 1 - v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) - imask = (~(kk == 0)).astype(int) - v *= imask - return v \ No newline at end of file + kk = sum(ki**2 for ki in kvec) + kl2 = kl**2 + ks4 = ks**4 + mask = (kk == 0).nonzero() + kk[mask] = 1 + v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) + imask = (~(kk == 0)).astype(int) + v *= imask + return v diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py index b4beeef..0143adc 100644 --- a/jaxpm/lensing.py +++ b/jaxpm/lensing.py @@ -1,11 +1,12 @@ -import jax +import jax import jax.numpy as jnp -import jax_cosmo.constants as constants import jax_cosmo - +import jax_cosmo.constants as constants from jax.scipy.ndimage import map_coordinates -from jaxpm.utils import gaussian_smoothing + from jaxpm.painting import cic_paint_2d +from jaxpm.utils import gaussian_smoothing + def density_plane(positions, box_shape, @@ -26,9 +27,11 @@ def density_plane(positions, xy = xy / nx * plane_resolution # Selecting only particles that fall inside the volume of interest - weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.) + weight = jnp.where( + (d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.) # Painting density plane - density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight) + density_plane = cic_paint_2d( + jnp.zeros([plane_resolution, plane_resolution]), xy, weight) # Apply density normalization density_plane = density_plane / ((nx / plane_resolution) * @@ -36,45 +39,44 @@ def density_plane(positions, # Apply Gaussian smoothing if requested if smoothing_sigma is not None: - density_plane = gaussian_smoothing(density_plane, - smoothing_sigma) + density_plane = gaussian_smoothing(density_plane, smoothing_sigma) return density_plane -def convergence_Born(cosmo, - density_planes, - coords, - z_source): - """ +def convergence_Born(cosmo, density_planes, coords, z_source): + """ Compute the Born convergence Args: cosmo: `Cosmology`, cosmology object. - density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use + density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2]. z_source: 1-D `Tensor` of source redshifts with shape [Nz] . name: `string`, name of the operation. Returns: `Tensor` of shape [batch_size, N, Nz], of convergence values. """ - # Compute constant prefactor: - constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2 - # Compute comoving distance of source galaxies - r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source)) + # Compute constant prefactor: + constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2 + # Compute comoving distance of source galaxies + r_s = jax_cosmo.background.radial_comoving_distance( + cosmo, 1 / (1 + z_source)) - convergence = 0 - for entry in density_planes: - r = entry['r']; a = entry['a']; p = entry['plane'] - dx = entry['dx']; dz = entry['dz'] - # Normalize density planes - density_normalization = dz * r / a - p = (p - p.mean()) * constant_factor * density_normalization + convergence = 0 + for entry in density_planes: + r = entry['r'] + a = entry['a'] + p = entry['plane'] + dx = entry['dx'] + dz = entry['dz'] + # Normalize density planes + density_normalization = dz * r / a + p = (p - p.mean()) * constant_factor * density_normalization - # Interpolate at the density plane coordinates - im = map_coordinates(p, - coords * r / dx - 0.5, - order=1, mode="wrap") + # Interpolate at the density plane coordinates + im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap") - convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1]) + convergence += im * jnp.clip(1. - + (r / r_s), 0, 1000).reshape([-1, 1, 1]) - return convergence + return convergence diff --git a/jaxpm/nn.py b/jaxpm/nn.py index 933ea53..d8f27be 100644 --- a/jaxpm/nn.py +++ b/jaxpm/nn.py @@ -1,6 +1,7 @@ +import haiku as hk import jax import jax.numpy as jnp -import haiku as hk + def _deBoorVectorized(x, t, c, p): """ @@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p): c: array of control points p: degree of B-spline """ - k = jnp.digitize(x, t) -1 - - d = [c[j + k - p] for j in range(0, p+1)] - for r in range(1, p+1): - for j in range(p, r-1, -1): - alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p]) - d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j] + k = jnp.digitize(x, t) - 1 + + d = [c[j + k - p] for j in range(0, p + 1)] + for r in range(1, p + 1): + for j in range(p, r - 1, -1): + alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p]) + d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j] return d[p] class NeuralSplineFourierFilter(hk.Module): - """A rotationally invariant filter parameterized by + """A rotationally invariant filter parameterized by a b-spline with parameters specified by a small NN.""" - def __init__(self, n_knots=8, latent_size=16, name=None): + def __init__(self, n_knots=8, latent_size=16, name=None): + """ + n_knots: number of control points for the spline """ - n_knots: number of control points for the spline - """ - super().__init__(name=name) - self.n_knots = n_knots - self.latent_size = latent_size + super().__init__(name=name) + self.n_knots = n_knots + self.latent_size = latent_size - def __call__(self, x, a): - """ + def __call__(self, x, a): + """ x: array, scale, normalized to fftfreq default a: scalar, scale factor """ - net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a))) - net = jnp.sin(hk.Linear(self.latent_size)(net)) + net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a))) + net = jnp.sin(hk.Linear(self.latent_size)(net)) - w = hk.Linear(self.n_knots+1)(net) - k = hk.Linear(self.n_knots-1)(net) - - # make sure the knots sum to 1 and are in the interval 0,1 - k = jnp.concatenate([jnp.zeros((1,)), - jnp.cumsum(jax.nn.softmax(k))]) + w = hk.Linear(self.n_knots + 1)(net) + k = hk.Linear(self.n_knots - 1)(net) - w = jnp.concatenate([jnp.zeros((1,)), - w]) + # make sure the knots sum to 1 and are in the interval 0,1 + k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))]) - # Augment with repeating points - ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))]) + w = jnp.concatenate([jnp.zeros((1, )), w]) - return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3) \ No newline at end of file + # Augment with repeating points + ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))]) + + return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w, + 3) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 4237c23..fb5dbd5 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,96 +1,100 @@ import jax -import jax.numpy as jnp import jax.lax as lax +import jax.numpy as jnp -from jaxpm.kernels import fftk, cic_compensation +from jaxpm.kernels import cic_compensation, fftk -def cic_paint(mesh, positions): - """ Paints positions onto mesh + +def cic_paint(mesh, positions, weight=None): + """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], - [0., 0, 1], [1., 1, 0], [1., 0, 1], - [0., 1, 1], [1., 1, 1]]]) + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], + [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + if weight is not None: + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) - neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) + neighboor_coords = jnp.mod( + neighboor_coords.reshape([-1, 8, 3]).astype('int32'), + jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, + 2)) + mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]), + dnums) + return mesh - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1, 2), - scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, - neighboor_coords, - kernel.reshape([-1,8]), - dnums) - return mesh def cic_read(mesh, positions): - """ Paints positions onto mesh + """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], - [0., 0, 1], [1., 1, 0], [1., 0, 1], - [0., 1, 1], [1., 1, 1]]]) + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], + [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape)) + neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), + jnp.array(mesh.shape)) + + return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1], + neighboor_coords[..., 3]] * kernel).sum(axis=-1) - return (mesh[neighboor_coords[...,0], - neighboor_coords[...,1], - neighboor_coords[...,3]]*kernel).sum(axis=-1) def cic_paint_2d(mesh, positions, weight): - """ Paints positions onto a 2d mesh + """ Paints positions onto a 2d mesh mesh: [nx, ny] positions: [npart, 2] weight: [npart] """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] - if weight is not None: - kernel = kernel * weight[...,jnp.newaxis] - - neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] + if weight is not None: + kernel = kernel * weight[..., jnp.newaxis] + + neighboor_coords = jnp.mod( + neighboor_coords.reshape([-1, 4, 2]).astype('int32'), + jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, + 1)) + mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]), + dnums) + return mesh - dnums = jax.lax.ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1)) - mesh = lax.scatter_add(mesh, - neighboor_coords, - kernel.reshape([-1,4]), - dnums) - return mesh def compensate_cic(field): - """ + """ Compensate for CiC painting Args: field: input 3D cic-painted field Returns: compensated_field """ - nc = field.shape - kvec = fftk(nc) + nc = field.shape + kvec = fftk(nc) - delta_k = jnp.fft.rfftn(field) - delta_k = cic_compensation(kvec) * delta_k - return jnp.fft.irfftn(delta_k) \ No newline at end of file + delta_k = jnp.fft.rfftn(field) + delta_k = cic_compensation(kvec) * delta_k + return jnp.fft.irfftn(delta_k) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 8e9e052..686cc07 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,11 +1,12 @@ import jax import jax.numpy as jnp - import jax_cosmo as jc -from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel +from jaxpm.growth import dGfa, growth_factor, growth_rate +from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, + longrange_kernel) from jaxpm.painting import cic_paint, cic_read -from jaxpm.growth import growth_factor, growth_rate, dGfa + def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): """ @@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): delta_k = jnp.fft.rfftn(delta) # Computes gravitational potential - pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) + pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, + r_split=r_split) # Computes gravitational forces - return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions) - for i in range(3)],axis=-1) + return jnp.stack([ + cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions) + for i in range(3) + ], + axis=-1) def lpt(cosmo, initial_conditions, positions, a): @@ -34,25 +39,31 @@ def lpt(cosmo, initial_conditions, positions, a): initial_force = pm_forces(positions, delta=initial_conditions) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force - p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx - f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force + p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, + a)) * dx + f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, + a) * initial_force return dx, p, f + def linear_field(mesh_shape, box_size, pk, seed): """ Generate initial conditions. """ kvec = fftk(mesh_shape) - kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 - pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) + kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 + for i, kk in enumerate(kvec))**0.5 + pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / ( + box_size[0] * box_size[1] * box_size[2]) field = jax.random.normal(seed, mesh_shape) field = jnp.fft.rfftn(field) * pkmesh**0.5 field = jnp.fft.irfftn(field) return field + def make_ode_fn(mesh_shape): - + def nbody_ode(state, a, cosmo): """ state is a tuple (position, velocities) @@ -63,10 +74,10 @@ def make_ode_fn(mesh_shape): # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - + # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces - + return dpos, dvel return nbody_ode @@ -128,4 +139,3 @@ def make_neural_ode_fn(model, mesh_shape): return dpos, dvel return neural_nbody_ode - diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 0249174..1593ba0 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,85 +1,87 @@ -import numpy as np import jax.numpy as jnp +import numpy as np from jax.scipy.stats import norm __all__ = ['power_spectrum'] + def _initialize_pk(shape, boxsize, kmin, dk): - """ + """ Helper function to initialize various (fixed) values for powerspectra... not differentiable! """ - I = np.eye(len(shape), dtype='int') * -2 + 1 + I = np.eye(len(shape), dtype='int') * -2 + 1 - W = np.empty(shape, dtype='f4') - W[...] = 2.0 - W[..., 0] = 1.0 - W[..., -1] = 1.0 + W = np.empty(shape, dtype='f4') + W[...] = 2.0 + W[..., 0] = 1.0 + W[..., -1] = 1.0 - kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 - kedges = np.arange(kmin, kmax, dk) + kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 + kedges = np.arange(kmin, kmax, dk) - k = [ - np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) - for N, L, kshape, pkshape in zip(shape, boxsize, I, shape) - ] - kmag = sum(ki**2 for ki in k)**0.5 + k = [ + np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) + for N, L, kshape, pkshape in zip(shape, boxsize, I, shape) + ] + kmag = sum(ki**2 for ki in k)**0.5 - xsum = np.zeros(len(kedges) + 1) - Nsum = np.zeros(len(kedges) + 1) + xsum = np.zeros(len(kedges) + 1) + Nsum = np.zeros(len(kedges) + 1) - dig = np.digitize(kmag.flat, kedges) + dig = np.digitize(kmag.flat, kedges) - xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) - Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size) - return dig, Nsum, xsum, W, k, kedges + xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) + Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size) + return dig, Nsum, xsum, W, k, kedges def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): - """ + """ Calculate the powerspectra given real space field - + Args: - - field: real valued field + + field: real valued field kmin: minimum k-value for binned powerspectra dk: differential in each kbin boxsize: length of each boxlength (can be strangly shaped?) - + Returns: - + kbins: the central value of the bins for plotting power: real valued array of power in each bin - + """ - shape = field.shape - nx, ny, nz = shape + shape = field.shape + nx, ny, nz = shape - #initialze values related to powerspectra (mode bins and weights) - dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) + #initialze values related to powerspectra (mode bins and weights) + dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) - #fast fourier transform - fft_image = jnp.fft.fftn(field) + #fast fourier transform + fft_image = jnp.fft.fftn(field) - #absolute value of fast fourier transform - pk = jnp.real(fft_image * jnp.conj(fft_image)) + #absolute value of fast fourier transform + pk = jnp.real(fft_image * jnp.conj(fft_image)) + #calculating powerspectra + real = jnp.real(pk).reshape([-1]) + imag = jnp.imag(pk).reshape([-1]) - #calculating powerspectra - real = jnp.real(pk).reshape([-1]) - imag = jnp.imag(pk).reshape([-1]) + Psum = jnp.bincount(dig, weights=(W.flatten() * imag), + length=xsum.size) * 1j + Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) - Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j - Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) + P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') - P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') + #normalization for powerspectra + norm = np.prod(np.array(shape[:])).astype('float32')**2 - #normalization for powerspectra - norm = np.prod(np.array(shape[:])).astype('float32')**2 + #find central values of each bin + kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 - #find central values of each bin - kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 + return kbins, P / norm - return kbins, P / norm def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False): """ @@ -131,18 +133,17 @@ def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=Fals def gaussian_smoothing(im, sigma): - """ + """ im: 2d image - sigma: smoothing scale in px + sigma: smoothing scale in px """ - # Compute k vector - kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]), - jnp.fft.fftfreq(im.shape[1])), - axis=-1) - k = jnp.linalg.norm(kvec, axis=-1) - # We compute the value of the filter at frequency k - filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma)) - filter /= filter[0,0] - - return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real + # Compute k vector + kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]), + jnp.fft.fftfreq(im.shape[1])), + axis=-1) + k = jnp.linalg.norm(kvec, axis=-1) + # We compute the value of the filter at frequency k + filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma)) + filter /= filter[0, 0] + return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real diff --git a/setup.py b/setup.py index 44be5a1..a58759a 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name='JaxPM', @@ -6,6 +6,6 @@ setup( url='https://github.com/DifferentiableUniverseInitiative/JaxPM', author='JaxPM developers', description='A dead simple FastPM implementation in JAX', - packages=find_packages(), + packages=find_packages(), install_requires=['jax', 'jax_cosmo'], -) \ No newline at end of file +)