JaxPM/notebooks/ParallelNbody-LPT.ipynb
2022-10-19 17:37:38 -07:00

1.4 MiB

In [1]:
%pylab inline
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax.experimental.maps import xmap
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec, pjit
from jaxpm.pm import cic_paint, cic_read, fftk
from functools import partial
import jax_cosmo as jc
Populating the interactive namespace from numpy and matplotlib
In [115]:
# Simulation configuration
nc=512          # number of mesh cells per side (the full mesh is nc**3)
boxsize=1024.   # box side length, in Mpc/h
halo_size=16    # halo width, in cells, added on each side of a device's sub-domain
In [151]:
def cic_paint(mesh, positions):
  """ Paints positions onto mesh with Cloud-In-Cell (trilinear) weights.

  Each particle deposits a unit weight split between the 8 corners of the
  cell containing it. Indices wrap periodically along every axis of `mesh`,
  each axis using its own length, so non-cubic meshes are handled.

  mesh: [nx, ny, nz] array the weights are added into
  positions: [npart, 3] particle coordinates, in units of mesh cells
  returns: mesh with the particle weights added (same shape as `mesh`)
  """
  positions = jnp.expand_dims(positions, 1)          # [npart, 1, 3]
  floor = jnp.floor(positions)
  # Offsets from the lower cell corner to all 8 corners.
  connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
                           [0., 0, 1], [1., 1, 0], [1., 0, 1],
                           [0., 1, 1], [1., 1, 1]]])

  neighboor_coords = floor + connection              # [npart, 8, 3]
  kernel = 1. - jnp.abs(positions - neighboor_coords)
  kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]   # [npart, 8]

  # Wrap each axis by its own extent: a halo-padded sub-domain such as
  # [nx + 2h, ny + 2h, nz] is not cubic, so a single modulus is wrong there.
  mesh_shape = jnp.array(mesh.shape, dtype='int32')
  neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), mesh_shape)

  dnums = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0, 1, 2),
    scatter_dims_to_operand_dims=(0, 1, 2))
  mesh = lax.scatter_add(mesh,
                         neighboor_coords,
                         kernel.reshape([-1,8]),
                         dnums)
  return mesh

def cic_read(mesh, positions):
  """ Reads (interpolates) mesh values at positions with Cloud-In-Cell weights.

  Indices wrap periodically along every axis of `mesh`, each axis using its
  own length, so non-cubic meshes are handled.

  mesh: [nx, ny, nz]
  positions: [npart, 3] coordinates in units of mesh cells
  returns: [npart] interpolated values
  """
  positions = jnp.expand_dims(positions, 1)          # [npart, 1, 3]
  floor = jnp.floor(positions)
  # Offsets from the lower cell corner to all 8 corners.
  connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
                           [0., 0, 1], [1., 1, 0], [1., 0, 1],
                           [0., 1, 1], [1., 1, 1]]])

  neighboor_coords = floor + connection              # [npart, 8, 3]
  kernel = 1. - jnp.abs(positions - neighboor_coords)
  kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]   # [npart, 8]

  # Wrap each axis by its own extent (the mesh need not be cubic).
  mesh_shape = jnp.array(mesh.shape, dtype='int32')
  neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), mesh_shape)

  # Gather the 8 corner values (coordinate components 0, 1, 2) and
  # combine them with the CIC weights.
  return (mesh[neighboor_coords[...,0],
               neighboor_coords[...,1],
               neighboor_coords[...,2]]*kernel).sum(axis=-1)
In [152]:
# Defining the main operations
# Draw standard-normal white noise, one z-column of nc samples per (x, y) key.
# Only the axes that actually exist ('x' and 'y') are given mesh resources.
@partial(xmap,
         in_axes={0:'x', 1:'y'},
         out_axes=['x','y',...],
         axis_sizes={'x':nc, 'y':nc},
         axis_resources={'x': 'nx', 'y':'ny'})
def pnormal(key):
    """Sample `nc` i.i.d. standard normal values from `key`.

    key: a single PRNG key; input axes 0 and 1 of the key array are mapped
        onto the named axes 'x' and 'y'.
    returns: [nc] samples per (x, y); xmap reassembles them into an
        [nc, nc, nc] array ordered ('x', 'y', z).
    """
    return jax.random.normal(key, shape=[nc])

@partial(xmap,
         in_axes={0:'x', 1:'y'},
         out_axes=['x','y',...],
         axis_resources={'x': 'nx', 'y': 'ny'})
def pfft3d(mesh):
    """Distributed 3D FFT built from three 1D FFTs and two all-to-all transposes.

    Each xmap instance holds one column along the local (positional) axis.
    After each 1D FFT, `all_to_all` swaps the local axis with a named axis, so
    the next dimension becomes local. Following the transposes, the layout goes
    [x, y, z] -> [z, y, x] (after 'x') -> [z, x, y] (after 'y').
    """
    field_k = jnp.fft.fft(mesh)
    for axis_name in ('x', 'y'):
        field_k = lax.all_to_all(field_k, axis_name, 0, 0)
        field_k = jnp.fft.fft(field_k)
    return field_k

@partial(xmap,
         in_axes={0:'x', 1:'y'},
         out_axes=['x','y',...],
         axis_resources={'x': 'nx', 'y': 'ny'})
def pifft3d(mesh):
    """Inverse of pfft3d: 1D inverse FFTs with the transposes undone in reverse order.

    The 'y' transpose is undone first, then the 'x' one. The last inverse FFT
    is along the local axis, and only the real part is returned.
    """
    field = mesh
    for axis_name in ('y', 'x'):
        field = jnp.fft.ifft(field)
        field = lax.all_to_all(field, axis_name, 0, 0)
    return jnp.fft.ifft(field).real
In [153]:
@partial(xmap,
         in_axes=(['x','y',...],
                  ['x'],
                  ['y'],
                  [...],[...],[...]),
         out_axes=['x','y',...],
         axis_resources={'x': 'nx', 'y': 'ny'})
def cwise_fn(kfield, kx, ky, kz, k, pk):
    """Scale a Fourier-space field by the square root of an interpolated power spectrum.

    kfield: Fourier field, first two axes mapped on 'x' and 'y'.
    kx, ky: wavenumbers mapped on 'x' and 'y'; kz: wavenumbers along the local axis.
        NOTE(review): presumably in radians per cell as returned by fftk — confirm.
    k, pk: tabulated wavenumbers and power spectrum used for interpolation.
    """
    # Rescale by nc / boxsize, then take the norm of the wave vector.
    kx_phys = kx / boxsize * nc
    ky_phys = ky / boxsize * nc
    kz_phys = kz / boxsize * nc
    kk = jnp.sqrt(kx_phys**2 + ky_phys**2 + kz_phys**2)
    pkmesh = jc.scipy.interpolate.interp(kk, k, pk)
    # Normalisation nc**3 / boxsize**3, i.e. number of cells per unit volume.
    amplitude = (pkmesh*nc**3/boxsize**3)**0.5
    return kfield*amplitude

def get_initial_cond(cosmo, seed):
    """Generate a Gaussian random field on an [nc, nc, nc] mesh, coloured by the linear power spectrum.

    cosmo: jax_cosmo cosmology passed to `jc.power.linear_matter_power`.
    seed: PRNG key, split into nc*nc keys (one per (x, y) column).
    returns: the real-space field produced by pifft3d.
    """
    # Unit white noise in real space, taken to Fourier space.
    column_keys = jax.random.split(seed, nc*nc).reshape(nc,nc,-1)
    noise_k = pfft3d(pnormal(column_keys))

    # Tabulated linear power spectrum and the mesh wavenumbers.
    k = jnp.logspace(-4, 2, 256)
    pk = jc.power.linear_matter_power(cosmo, k)
    kvec = fftk([nc,nc,nc], symmetric=False)
    kx, ky, kz = (kvec[axis].squeeze() for axis in range(3))

    field_k = cwise_fn(noise_k, kx, ky, kz, k, pk)
    return pifft3d(field_k)
In [154]:
# PRNG key from which the initial conditions are drawn.
key = jax.random.PRNGKey(42)

# Arrange the local devices on a 2 x 2 grid (this requires exactly 4 local
# devices); its axes are later named 'nx' and 'ny'.
devices = np.asarray(jax.local_devices()).reshape(2, 2)

# Fiducial cosmology.
cosmo = jc.Planck15()

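Now let's implement first-order Lagrangian Perturbation Theory (LPT): particles start on a regular grid and are displaced by the gradient of the inverse Laplacian of the linear density field.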

In [155]:
@partial(xmap,
         in_axes=([...]),
         out_axes={0:'sx', 2:'sy'},
         axis_sizes={'sx':2, 'sy':2},
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def pmeshgrid(x, y, z):
    """Regular grid of particle coordinates, replicated on every device shard.

    x, y, z: 1D coordinate arrays (identical on every shard).
    returns: [len(x), len(y), len(z), 3] array with out[i, j, k] = (x[i], y[j], z[k]);
        xmap inserts the two size-2 shard axes at positions 0 and 2.
    """
    # indexing='ij' keeps the axis order (x, y, z). The default 'xy' swaps the
    # first two axes, so out[i, j] would hold (x[j], y[i]).
    return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)

# LPT displacement in Fourier space: inverse-Laplacian kernel times a
# finite-difference gradient kernel, one output per spatial component.
@partial(xmap,
         in_axes=(['x','y','z'],
                  ['x'], ['y'], ['z']),
         out_axes=(['x','y','z'],
                   ['x','y','z'],
                   ['x','y','z']),
         axis_resources={'x': 'nx', 'y': 'ny'})
def papply_gradient_laplace(kfield, kx, ky, kz):
    """Apply inverse-Laplacian and gradient kernels to a Fourier-space field.

    kfield: Fourier field whose three axes are named 'x', 'y', 'z'.
    kx, ky, kz: 1D wavenumbers mapped onto 'x', 'y', 'z' respectively.
    returns: tuple of three Fourier fields, each kfield * i * D(k) / |k|^2,
        where D(k) is the finite-difference derivative response along one axis.

    The factor i * D(k) / k^2 corresponds to -grad(inverse Laplacian), using
    grad -> i k and inverse Laplacian -> -1 / k^2.

    NOTE(review): the outputs are ordered (ky, kz, kx). This appears to
    compensate for pfft3d leaving its output in [z, x, y] layout, so that the
    tuple follows the (x, y, z) order of the real-space mesh — confirm.
    """
    kk = (kx**2 + ky**2 + kz**2)
    # kk == 0 only for the zero mode, where every sin(k) term below is 0, so
    # replacing 1/0 by 1 just avoids a division by zero.
    kernel = jnp.where(kk == 0, 1., 1./kk)
    # 1j/6 * (8 sin(k) - sin(2k)) is the Fourier response of the fourth-order
    # central-difference derivative stencil.
    return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
            kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
            kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))

# Split the two leading (sharded) axes of an [nc, nc, ...] array into
# [2, nc//2] blocks: the result is [2, nc//2, 2, nc//2, ...], with the
# size-2 block axes sharded over 'nx' and 'ny'.
preshape = pjit(lambda x: x.reshape((2, nc//2, 2, nc//2, *x.shape[2:])),
                in_axis_resources=PartitionSpec('nx','ny'),
                out_axis_resources=PartitionSpec('nx', None, 'ny', None))

# Inverse of preshape: merge the [2, nc//2, 2, nc//2] block axes back into [nc, nc].
pireshape = pjit(lambda x: x.reshape((nc, nc, *x.shape[4:])),
                 in_axis_resources=PartitionSpec('nx', None, 'ny', None),
                 out_axis_resources=PartitionSpec('nx','ny'))

@partial(xmap,
         in_axes=({0:'sx',2:'sy'},
                  {0:'sx',2:'sy'}),
         out_axes=({0:'sx',2:'sy'}),
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def pcic_read(mesh, pos):
    """Per-shard CIC interpolation of `mesh` at particle positions `pos`.

    mesh: local block of the sharded mesh (axes 0 and 2 are the shard axes).
    pos: local particle coordinates [..., 3] for the matching shard.
    returns: interpolated values with shape pos.shape[:-1].
    """
    values = cic_read(mesh, pos.reshape(-1, 3))
    return values.reshape(pos.shape[:-1])

@partial(xmap,
         in_axes=({0:'sx',2:'sy'},
                  {0:'sx',2:'sy'}),
         out_axes=({0:'sx',2:'sy'}),
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def pcic_paint(mesh, pos, halo_size=halo_size):
    """Per-shard CIC painting of particles onto a halo-padded local mesh.

    mesh: local block of the sharded output mesh (axes 0 and 2 are the shard axes).
    pos: local particle coordinates [..., 3] for the matching shard.
    halo_size: shift added to the first two coordinates so particles land
        inside the padded block; the third coordinate is not shifted.
    """
    flat_pos = pos.reshape(-1, 3)
    halo_offset = jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
    return cic_paint(mesh, flat_pos + halo_offset)

@partial(xmap,
         in_axes=({0:'sx',2:'sy'},),
         out_axes={0:'sx',2:'sy'},
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def pad_mesh(mesh, halo_size=halo_size):
    """Zero-pad each device's sub-domain with a halo along the two sharded axes.

    mesh: local 3D block per device. Assumes the global array is
        [2, nx, 2, ny, nz], sharded on axes 0 and 2, e.g. as built by preshape.
    halo_size: halo width in cells. It must be a static Python int because
        jnp.pad needs static pad widths, which is why it is not in in_axes.
    returns: local [nx + 2*halo_size, ny + 2*halo_size, nz] block.
    """
    # Only the two sharded axes get a halo. This matches pcic_paint (offset
    # [halo_size, halo_size, 0]) and halo_reduce (which trims axes 0 and 1).
    pad_width = [(halo_size, halo_size), (halo_size, halo_size), (0, 0)]
    return jnp.pad(mesh, pad_width)

@partial(xmap,
         in_axes=({0:'sx',2:'sy'}),
         out_axes={0:'sx',2:'sy'},
         axis_resources={'sx': 'nx', 'sy': 'ny'})
def halo_reduce(mesh, halo_size=halo_size):
    """Fold halo contributions onto the neighbouring device, then drop the halos.

    mesh: local [nx + 2*halo_size, ny + 2*halo_size, nz] block per device.
        Assumes exactly 2 shards along each of 'sx' and 'sy'; see the pshuffle perm.
    halo_size: halo width in cells (static Python int).
    returns: local [nx, ny, nz] block with the neighbours' halo contributions added.
    """
    if halo_size == 0:
        return mesh

    margin = 2 * halo_size
    for axis_ind, axis_name in enumerate(['sx', 'sy']):
        axis_len = mesh.shape[axis_ind]
        # Edge slabs of width 2*halo_size: the outer halo plus the first/last
        # halo_size interior cells.
        left_margin = lax.slice_in_dim(mesh, 0, margin, axis=axis_ind)
        right_margin = lax.slice_in_dim(mesh, axis_len - margin, axis_len, axis=axis_ind)

        # Exchange slabs with the other device along this axis.
        # perm=[1, 0] swaps the two shards.
        left = lax.pshuffle(right_margin, perm=[1,0], axis_name=axis_name)
        right = lax.pshuffle(left_margin, perm=[1,0], axis_name=axis_name)
        if axis_ind == 0:
            mesh = mesh.at[:margin].add(left)
            mesh = mesh.at[-margin:].add(right)
        else:
            mesh = mesh.at[:, :margin].add(left)
            mesh = mesh.at[:, -margin:].add(right)

    # Remove the now-redundant halo cells along both sharded axes.
    return mesh[halo_size:mesh.shape[0] - halo_size,
                halo_size:mesh.shape[1] - halo_size]
In [171]:
with Mesh(devices, ('nx', 'ny')):

    initial_conditions = get_initial_cond(cosmo, key)

    # Create the particles on a regular grid. pmeshgrid is replicated, so
    # every shard holds the same local coordinates 0..nc//2-1 in x and y.
    pos = pmeshgrid(jnp.arange(nc//2), jnp.arange(nc//2), jnp.arange(nc))

    # Take the FFT of the field
    lineark = pfft3d(initial_conditions)

    # Apply the inverse-Laplacian and gradient kernels
    kvec = fftk([nc,nc,nc], symmetric=False)
    kforces = papply_gradient_laplace(lineark,
                                      kvec[0].squeeze(),
                                      kvec[1].squeeze(),
                                      kvec[2].squeeze())

    # Inverse Fourier transform each component, split it into device blocks,
    # and read it at the particle positions
    forces_x = pcic_read(preshape(pifft3d(kforces[0])), pos)
    forces_y = pcic_read(preshape(pifft3d(kforces[1])), pos)
    forces_z = pcic_read(preshape(pifft3d(kforces[2])), pos)

    # Stack the three components into one displacement vector per particle
    dx = xmap(lambda a,b,c: jnp.array([a,b,c]),
              in_axes=(['sx','tx','sy','ty','z',...],
                       ['sx','tx','sy','ty','z',...],
                       ['sx','tx','sy','ty','z',...]),
              out_axes=['sx','tx','sy','ty','z',...],
              axis_resources={'sx': 'nx', 'sy': 'ny'}
             )(forces_x, forces_y, forces_z)

    # Displace the particles
    x_final = xmap(lambda a,b:a+b,
                    in_axes=(['sx','tx','sy','ty','z',...],
                       ['sx','tx','sy','ty','z',...]),
                    out_axes=['sx','tx','sy','ty','z',...],
                   axis_resources={'sx': 'nx', 'sy': 'ny'}
                  )(dx,pos)

    # Paint the displaced particles onto a halo-padded block per shard,
    # [2, nc//2 + 2*halo_size, 2, nc//2 + 2*halo_size, nc]. Then fold the
    # halos back and reassemble the full [nc, nc, nc] field.
    res = pcic_paint(jnp.zeros([2, nc//2 + halo_size*2, 2, nc//2 + halo_size*2, nc]), x_final)
    res = halo_reduce(res)
    res = pireshape(res).block_until_ready()
In [160]:
figure(figsize=[10,10])
title('Final (LPT) density field, summed over the last axis')
imshow(res.sum(axis=-1))
Out[160]:
<matplotlib.image.AxesImage at 0x15343c2cceb0>
No description has been provided for this image
In [141]:
figure(figsize=[10,10])
title('Initial conditions, summed over the last axis')
# NOTE(review): `0 +` presumably forces a computation on the sharded array
# before plotting; `.real` is redundant because pifft3d already returns the
# real part — confirm and simplify.
imshow((0+initial_conditions).real.sum(axis=-1))
Out[141]:
<matplotlib.image.AxesImage at 0x1535400f9700>
No description has been provided for this image
In [ ]: