mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 18:10:55 +00:00
1.4 MiB
1.4 MiB
In [1]:
%pylab inline
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax.experimental.maps import xmap
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec, pjit
from jaxpm.pm import cic_paint, cic_read, fftk
from functools import partial
import jax_cosmo as jc
In [115]:
# Simulation configuration.
nc = 512          # number of grid cells per side of the [nc, nc, nc] mesh
boxsize = 1024.   # side length of the simulation box, in Mpc/h
halo_size = 16    # width (in cells) of the halo padding added around each shard for painting
In [151]:
def cic_paint(mesh, positions):
    """Paint particles onto a 3D mesh with cloud-in-cell (CIC) interpolation.

    Each particle deposits a total weight of 1, split trilinearly over the 8
    mesh cells surrounding it. Neighbour indices wrap periodically along each
    axis using that axis' own length, so non-cubic meshes are handled correctly.

    Args:
      mesh: [nx, ny, nz] array the particle weights are added to.
      positions: [npart, 3] particle coordinates, in units of mesh cells.

    Returns:
      A new [nx, ny, nz] array equal to ``mesh`` plus the painted weights.
    """
    positions = jnp.expand_dims(positions, 1)  # [npart, 1, 3]
    floor = jnp.floor(positions)
    # Offsets of the 8 corners of the unit cell containing each particle.
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
                             [0., 0, 1], [1., 1, 0], [1., 0, 1],
                             [0., 1, 1], [1., 1, 1]]])
    neighboor_coords = floor + connection  # [npart, 8, 3]
    # Per-axis linear weights, multiplied into one trilinear weight per corner.
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  # [npart, 8]
    # Periodic wrap with each axis' own size (not just the last axis' size).
    neighboor_coords = jnp.mod(neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
                               jnp.array(mesh.shape, dtype='int32'))
    # Each (particle, corner) pair scatters one scalar into mesh[i, j, k].
    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0, 1, 2),
        scatter_dims_to_operand_dims=(0, 1, 2))
    mesh = lax.scatter_add(mesh,
                           neighboor_coords,
                           kernel.reshape([-1, 8]),
                           dnums)
    return mesh
def cic_read(mesh, positions):
    """Interpolate mesh values at particle positions with cloud-in-cell (CIC) weights.

    This is the read-out counterpart of ``cic_paint``. Each particle receives
    the trilinear-weighted sum of the 8 mesh cells surrounding it. Neighbour
    indices wrap periodically along each axis using that axis' own length.

    Args:
      mesh: [nx, ny, nz] array to sample.
      positions: [npart, 3] particle coordinates, in units of mesh cells.

    Returns:
      [npart] array of interpolated values.
    """
    positions = jnp.expand_dims(positions, 1)  # [npart, 1, 3]
    floor = jnp.floor(positions)
    # Offsets of the 8 corners of the unit cell containing each particle.
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
                             [0., 0, 1], [1., 1, 0], [1., 0, 1],
                             [0., 1, 1], [1., 1, 1]]])
    neighboor_coords = floor + connection  # [npart, 8, 3]
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  # [npart, 8]
    # Periodic wrap with each axis' own size.
    neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
                               jnp.array(mesh.shape, dtype='int32'))
    # Gather the 8 corner values (the z index is component 2, not 3).
    return (mesh[neighboor_coords[..., 0],
                 neighboor_coords[..., 1],
                 neighboor_coords[..., 2]] * kernel).sum(axis=-1)
In [152]:
# Defining the main operations
@partial(xmap,
         in_axes={0: 'x', 1: 'y'},
         out_axes=['x', 'y', ...],
         axis_sizes={'x': nc, 'y': nc},
         # Only the axes that actually exist ('x', 'y') are assigned resources;
         # the former 'key_x'/'key_y' entries referred to no axis.
         axis_resources={'x': 'nx', 'y': 'ny'})
def pnormal(key):
    """Draw ``nc`` standard-normal samples for every (x, y) PRNG key.

    The input's first two dimensions are mapped to the named axes 'x' and 'y',
    so inside the map ``key`` is a single PRNG key. The result is laid out as
    [x, y, nc] and sharded over the 'nx'/'ny' mesh axes.
    """
    return jax.random.normal(key, shape=[nc])
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y',...],
axis_resources={'x': 'nx', 'y': 'ny'})
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0) # [z, y, x]
mesh = jnp.fft.fft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0) # [z, x, y]
return jnp.fft.fft(mesh)
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y',...],
axis_resources={'x': 'nx', 'y': 'ny'})
def pifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'y', 0, 0)
mesh = jnp.fft.ifft(mesh)
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
In [153]:
@partial(xmap,
         in_axes=(['x', 'y', ...],
                  ['x'],
                  ['y'],
                  [...], [...], [...]),
         out_axes=['x', 'y', ...],
         axis_resources={'x': 'nx', 'y': 'ny'})
def cwise_fn(kfield, kx, ky, kz, k, pk):
    """Scale a Fourier-space field by sqrt(P(|k|)) to impose a power spectrum.

    Args:
      kfield: complex field, mapped over named axes 'x' and 'y'.
      kx, ky: per-axis grid wavenumbers (kx mapped over 'x', ky over 'y').
      kz: grid wavenumbers along the positional axis.
      k, pk: tabulated wavenumbers and power spectrum, interpolated at |k|.

    The grid wavenumbers are rescaled by ``nc / boxsize`` before interpolation.
    Presumably fftk returns them per cell, which puts them in the units of ``k``
    (TODO confirm). The amplitude sqrt(P * nc**3 / boxsize**3) divides P by the
    cell volume (boxsize / nc)**3.
    """
    def to_box_units(kgrid):
        return kgrid / boxsize * nc

    kk = jnp.sqrt(to_box_units(kx)**2 + to_box_units(ky)**2 +
                  to_box_units(kz)**2)
    pkmesh = jc.scipy.interpolate.interp(kk, k, pk)
    amplitude = (pkmesh * nc**3 / boxsize**3)**0.5
    return kfield * amplitude
def get_initial_cond(cosmo, seed):
    """Generate a Gaussian random field with the linear matter power spectrum of ``cosmo``.

    Args:
      cosmo: jax_cosmo cosmology used for ``linear_matter_power``.
      seed: JAX PRNG key; it is split into nc * nc keys, one per (x, y) line.

    Returns:
      The real part of the inverse FFT of the coloured noise, an [nc, nc, nc] field.
    """
    # Unit-variance white noise, one PRNG key per (x, y) line.
    per_line_keys = jax.random.split(seed, nc * nc).reshape(nc, nc, -1)
    white_noise_k = pfft3d(pnormal(per_line_keys))
    # Linear P(k) tabulated on a log grid; cwise_fn interpolates it per mode.
    k_table = jnp.logspace(-4, 2, 256)
    pk_table = jc.power.linear_matter_power(cosmo, k_table)
    kvec = fftk([nc, nc, nc], symmetric=False)
    kx, ky, kz = kvec[0].squeeze(), kvec[1].squeeze(), kvec[2].squeeze()
    coloured_k = cwise_fn(white_noise_k, kx, ky, kz, k_table, pk_table)
    return pifft3d(coloured_k)
In [154]:
# Fixed PRNG seed so the initial conditions are reproducible.
key = jax.random.PRNGKey(42)
# Arrange the local devices into a 2x2 grid (this requires exactly 4 devices);
# its two axes are named ('nx', 'ny') when the Mesh is built below.
devices = np.asarray(jax.local_devices()).reshape(2, 2)
# Planck 2015 cosmology from jax_cosmo, used for the linear power spectrum.
cosmo = jc.Planck15()
Next, let's implement Lagrangian perturbation theory (LPT) to displace particles from a regular lattice.
In [155]:
@partial(xmap,
in_axes=([...]),
out_axes={0:'sx', 2:'sy'},
axis_sizes={'sx':2, 'sy':2},
axis_resources={'sx': 'nx', 'sy': 'ny'})
def pmeshgrid(x, y, z):
return jnp.stack(jnp.meshgrid(x,y,z),axis=-1)
@partial(xmap,
in_axes=(['x','y','z'],
['x'], ['y'], ['z']),
out_axes=(['x','y','z'],
['x','y','z'],
['x','y','z']),
axis_resources={'x': 'nx', 'y': 'ny'})
def papply_gradient_laplace(kfield, kx, ky, kz):
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
# Convert between the global [nc, nc, ...] layout and an explicit block layout
# [2, nc//2, 2, nc//2, ...], where dims 0 and 2 index the 2x2 device shards.
preshape = pjit(
    lambda x: x.reshape((2, nc // 2, 2, nc // 2) + tuple(x.shape[2:])),
    in_axis_resources=PartitionSpec('nx', 'ny'),
    out_axis_resources=PartitionSpec('nx', None, 'ny', None))
# Inverse of preshape: merge the shard blocks back into [nc, nc, ...].
pireshape = pjit(
    lambda x: x.reshape((nc, nc) + tuple(x.shape[4:])),
    in_axis_resources=PartitionSpec('nx', None, 'ny', None),
    out_axis_resources=PartitionSpec('nx', 'ny'))
@partial(xmap,
         in_axes=({0: 'sx', 2: 'sy'},
                  {0: 'sx', 2: 'sy'}),
         out_axes=({0: 'sx', 2: 'sy'}),
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def pcic_read(mesh, pos):
    """Per-shard CIC read-out: interpolate each shard's mesh at that shard's positions.

    ``pos`` carries coordinates in its last dimension (size 3). The returned
    values have ``pos``'s shape without that last dimension.
    """
    flat_positions = pos.reshape(-1, 3)
    sampled = cic_read(mesh, flat_positions)
    return sampled.reshape(pos.shape[:-1])
@partial(xmap,
         in_axes=({0: 'sx', 2: 'sy'},
                  {0: 'sx', 2: 'sy'}),
         out_axes=({0: 'sx', 2: 'sy'}),
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def pcic_paint(mesh, pos, halo_size=halo_size):
    """Per-shard CIC painting onto a halo-padded local mesh.

    Positions are shifted by ``halo_size`` cells in x and y (not z) so that
    local coordinate 0 lands just inside the left halo padding of the mesh.
    """
    halo_offset = jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
    return cic_paint(mesh, pos.reshape(-1, 3) + halo_offset)
def pad_mesh(mesh, halo_size=halo_size):
    """Zero-pad every shard of a sharded mesh by ``halo_size`` cells in x and y.

    ``mesh`` is sharded with dims 0 and 2 as the ('sx', 'sy') shard axes, so
    each local block is [nx, ny, nz]. It becomes [nx + 2*halo_size,
    ny + 2*halo_size, nz], the per-shard layout that halo_reduce trims back.
    z is not padded.

    The previous version referenced an undefined name (``messh``), passed an
    invalid pad_width to jnp.pad, and mapped ``halo_size`` as a traced input.
    Here ``halo_size`` stays a static Python int.
    """
    pad_width = [(halo_size, halo_size), (halo_size, halo_size), (0, 0)]
    padder = xmap(lambda local: jnp.pad(local, pad_width),
                  in_axes={0: 'sx', 2: 'sy'},
                  out_axes={0: 'sx', 2: 'sy'},
                  axis_resources={'sx': 'nx', 'sy': 'ny'})
    return padder(mesh)
@partial(xmap,
         in_axes=({0: 'sx', 2: 'sy'}),
         out_axes={0: 'sx', 2: 'sy'},
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def halo_reduce(mesh, halo_size=halo_size):
    """Sum overlapping halo margins between neighbouring shards, then strip the halos.

    Each local block is [nx + 2*halo_size, ny + 2*halo_size, nz]. For each
    sharded axis ('sx' then 'sy'), the 2*halo_size-wide margin at each edge is
    sent to the neighbouring shard and added to its opposite margin. The axes
    are processed one after the other on the updated mesh, so corner regions
    are also combined. Finally halo_size cells are dropped from each side of
    x and y.

    Note: perm=[1, 0] assumes exactly 2 shards along each sharded axis.
    """
    margin = 2 * halo_size
    for axis_ind, axis_name in enumerate(['sx', 'sy']):
        size = mesh.shape[axis_ind]
        # Slice the edges explicitly; jax arrays have no ``.split`` method, and
        # the right edge is located from the actual local size rather than nc.
        left_margin = lax.slice_in_dim(mesh, 0, margin, axis=axis_ind)
        right_margin = lax.slice_in_dim(mesh, size - margin, size, axis=axis_ind)
        # Receive the neighbour's right margin (it overlaps our left edge) and vice versa.
        left = lax.pshuffle(right_margin, perm=[1, 0], axis_name=axis_name)
        right = lax.pshuffle(left_margin, perm=[1, 0], axis_name=axis_name)
        lead = (slice(None),) * axis_ind
        mesh = mesh.at[lead + (slice(0, margin),)].add(left)
        mesh = mesh.at[lead + (slice(size - margin, size),)].add(right)
    # Remove the halos. Explicit end indices keep halo_size == 0 from yielding an empty [0:-0] slice.
    size_x, size_y = mesh.shape[0], mesh.shape[1]
    return mesh[halo_size:size_x - halo_size, halo_size:size_y - halo_size]
In [171]:
with Mesh(devices, ('nx', 'ny')):
    # Gaussian initial density field.
    initial_conditions = get_initial_cond(cosmo, key)

    # Particles on a regular lattice: every shard holds local coordinates
    # covering an [nc//2, nc//2, nc] block.
    pos = pmeshgrid(jnp.arange(nc // 2), jnp.arange(nc // 2), jnp.arange(nc))

    # Fourier transform the density, then apply the inverse Laplacian and the
    # finite-difference gradient to get the displacement components.
    lineark = pfft3d(initial_conditions)
    kvec = fftk([nc, nc, nc], symmetric=False)
    kforces = papply_gradient_laplace(lineark,
                                      kvec[0].squeeze(),
                                      kvec[1].squeeze(),
                                      kvec[2].squeeze())

    # Back to real space (inverse FFT), reshape into shard blocks, and read
    # each component at the particle positions.
    forces_x, forces_y, forces_z = [pcic_read(preshape(pifft3d(kf)), pos)
                                    for kf in kforces]

    # Stack the three components into one displacement vector per particle.
    dx = xmap(lambda a, b, c: jnp.array([a, b, c]),
              in_axes=(['sx', 'tx', 'sy', 'ty', 'z', ...],
                       ['sx', 'tx', 'sy', 'ty', 'z', ...],
                       ['sx', 'tx', 'sy', 'ty', 'z', ...]),
              out_axes=['sx', 'tx', 'sy', 'ty', 'z', ...],
              axis_resources={'sx': 'nx', 'sy': 'ny'})(forces_x, forces_y, forces_z)
    # Displace the particles from the lattice (LPT step).
    x_final = xmap(lambda a, b: a + b,
                   in_axes=(['sx', 'tx', 'sy', 'ty', 'z', ...],
                            ['sx', 'tx', 'sy', 'ty', 'z', ...]),
                   out_axes=['sx', 'tx', 'sy', 'ty', 'z', ...],
                   axis_resources={'sx': 'nx', 'sy': 'ny'})(dx, pos)

    # Paint the final field onto per-shard meshes padded by halo_size in x and y
    # (sizes derived from nc instead of hard-coded 256/512), then merge halos.
    padded_local = nc // 2 + 2 * halo_size
    res = pcic_paint(jnp.zeros([2, padded_local, 2, padded_local, nc]), x_final)
    res = halo_reduce(res)
    res = pireshape(res).block_until_ready()
In [160]:
# Final painted density field, projected (summed) along the last axis.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(res.sum(axis=-1))
Out[160]:
In [141]:
# Initial density field, projected (summed) along the last axis.
# NOTE(review): the `0 +` is kept from the original, presumably to materialise
# the sharded xmap output as a regular array before plotting. TODO confirm.
fig, ax = plt.subplots(figsize=(10, 10))
projected_ic = (0 + initial_conditions).real.sum(axis=-1)
ax.imshow(projected_ic)
Out[141]:
In [ ]: