forked from Aquila-Consortium/JaxPM_highres
update data
This commit is contained in:
parent
137f4e5099
commit
ebab27b4c3
2 changed files with 91 additions and 82 deletions
|
@ -6,7 +6,7 @@ from jax.experimental.maps import xmap
|
||||||
from jax.experimental.pjit import pjit, PartitionSpec
|
from jax.experimental.pjit import pjit, PartitionSpec
|
||||||
|
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import jaxpm as jpm
|
import jaxpm.painting as paint
|
||||||
|
|
||||||
# TODO: add a way to configure axis resources from command line
|
# TODO: add a way to configure axis resources from command line
|
||||||
axis_resources = {'x': 'nx', 'y': 'ny'}
|
axis_resources = {'x': 'nx', 'y': 'ny'}
|
||||||
|
@ -24,7 +24,7 @@ def stack3d(a, b, c):
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
@partial(xmap,
|
||||||
in_axes=({0: 'x', 2: 'y'}),
|
in_axes=({0: 'x', 2: 'y'},[...]),
|
||||||
out_axes=({0: 'x', 2: 'y'}),
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
axis_resources=axis_resources)
|
axis_resources=axis_resources)
|
||||||
def scalar_multiply(a, factor):
|
def scalar_multiply(a, factor):
|
||||||
|
@ -41,8 +41,8 @@ def add(a, b):
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
@partial(xmap,
|
||||||
in_axes=['x', 'y'],
|
in_axes=['x', 'y',...],
|
||||||
out_axes=['x', 'y'],
|
out_axes=['x', 'y',...],
|
||||||
axis_resources=axis_resources)
|
axis_resources=axis_resources)
|
||||||
def fft3d(mesh):
|
def fft3d(mesh):
|
||||||
""" Performs a 3D complex Fourier transform
|
""" Performs a 3D complex Fourier transform
|
||||||
|
@ -62,8 +62,8 @@ def fft3d(mesh):
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
@partial(xmap,
|
||||||
in_axes=['x', 'y'],
|
in_axes=['x', 'y',...],
|
||||||
out_axes=['x', 'y'],
|
out_axes=['x', 'y',...],
|
||||||
axis_resources=axis_resources)
|
axis_resources=axis_resources)
|
||||||
def ifft3d(mesh):
|
def ifft3d(mesh):
|
||||||
mesh = jnp.fft.ifft(mesh)
|
mesh = jnp.fft.ifft(mesh)
|
||||||
|
@ -72,24 +72,25 @@ def ifft3d(mesh):
|
||||||
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
||||||
return jnp.fft.ifft(mesh).real
|
return jnp.fft.ifft(mesh).real
|
||||||
|
|
||||||
|
def normal(key, shape=[]):
|
||||||
@partial(xmap,
|
@partial(xmap,
|
||||||
in_axes=['x', 'y'],
|
in_axes=['x', 'y',...],
|
||||||
out_axes={0: 'x', 2: 'y'},
|
out_axes={0: 'x', 2: 'y'},
|
||||||
axis_resources=axis_resources)
|
axis_resources=axis_resources)
|
||||||
def normal(key, shape):
|
def fn(key):
|
||||||
""" Generate a distributed random normal distributions
|
""" Generate a distributed random normal distributions
|
||||||
Args:
|
Args:
|
||||||
key: array of random keys with same layout as computational mesh
|
key: array of random keys with same layout as computational mesh
|
||||||
shape: logical shape of array to sample
|
shape: logical shape of array to sample
|
||||||
"""
|
"""
|
||||||
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
|
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
|
||||||
shape[1]//mesh_size['ny']]+shape[2:])
|
shape[1]//mesh_size['ny']]+shape[2:])
|
||||||
|
return fn(key)
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
@partial(xmap,
|
||||||
in_axes=(['x', 'y', ...],
|
in_axes=(['x', 'y', ...],
|
||||||
[['x'], ['y'], ...]),
|
[['x'], ['y'], [...]], [...], [...]),
|
||||||
out_axes=['x', 'y', ...],
|
out_axes=['x', 'y', ...],
|
||||||
axis_resources=axis_resources)
|
axis_resources=axis_resources)
|
||||||
@jax.jit
|
@jax.jit
|
||||||
|
@ -124,78 +125,85 @@ def meshgrid(x, y, z):
|
||||||
""" Generates a mesh grid of appropriate size for the
|
""" Generates a mesh grid of appropriate size for the
|
||||||
computational mesh we have.
|
computational mesh we have.
|
||||||
"""
|
"""
|
||||||
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
|
return jnp.stack(jnp.meshgrid(x,
|
||||||
|
y,
|
||||||
|
z), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
|
||||||
in_axes=({0: 'x', 2: 'y'}),
|
|
||||||
out_axes=({0: 'x', 2: 'y'}),
|
|
||||||
axis_resources=axis_resources)
|
|
||||||
def cic_paint(pos, mesh_shape, halo_size=0):
|
def cic_paint(pos, mesh_shape, halo_size=0):
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({0: 'x', 2: 'y'}),
|
||||||
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def fn(pos):
|
||||||
|
|
||||||
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
|
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
|
||||||
mesh_shape[1]//mesh_size['ny']+2*halo_size]
|
mesh_shape[1]//mesh_size['ny']+2*halo_size]
|
||||||
+ mesh_shape[2:])
|
+ mesh_shape[2:])
|
||||||
|
|
||||||
# Paint particles
|
# Paint particles
|
||||||
mesh = jpm.cic_paint(mesh, pos.reshape(-1, 3) +
|
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
|
||||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||||
|
|
||||||
# Perform halo exchange
|
# Perform halo exchange
|
||||||
# Halo exchange along x
|
# Halo exchange along x
|
||||||
left = lax.pshuffle(mesh[-halo_size:],
|
left = lax.pshuffle(mesh[-halo_size:],
|
||||||
perm=range(mesh_size['nx'])[::-1],
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
axis_name='x')
|
axis_name='x')
|
||||||
right = lax.pshuffle(mesh[:halo_size],
|
right = lax.pshuffle(mesh[:halo_size],
|
||||||
perm=range(mesh_size['nx'])[::-1],
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
axis_name='x')
|
axis_name='x')
|
||||||
mesh = mesh.at[:halo_size].add(left)
|
mesh = mesh.at[:halo_size].add(left)
|
||||||
mesh = mesh.at[-halo_size:].add(right)
|
mesh = mesh.at[-halo_size:].add(right)
|
||||||
|
|
||||||
# Halo exchange along y
|
# Halo exchange along y
|
||||||
left = lax.pshuffle(mesh[:, -halo_size:],
|
left = lax.pshuffle(mesh[:, -halo_size:],
|
||||||
perm=range(mesh_size['ny'])[::-1],
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
axis_name='y')
|
axis_name='y')
|
||||||
right = lax.pshuffle(mesh[:, :halo_size],
|
right = lax.pshuffle(mesh[:, :halo_size],
|
||||||
perm=range(mesh_size['ny'])[::-1],
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
axis_name='y')
|
axis_name='y')
|
||||||
mesh = mesh.at[:, :halo_size].add(left)
|
mesh = mesh.at[:, :halo_size].add(left)
|
||||||
mesh = mesh.at[:, -halo_size:].add(right)
|
mesh = mesh.at[:, -halo_size:].add(right)
|
||||||
|
|
||||||
# removing halo and returning mesh
|
# removing halo and returning mesh
|
||||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
||||||
|
return fn(pos)
|
||||||
|
|
||||||
@partial(xmap,
|
def cic_read(mesh, pos, halo_size=0):
|
||||||
in_axes=({0: 'x', 2: 'y'},
|
@partial(xmap,
|
||||||
{0: 'x', 2: 'y'}),
|
in_axes=({0: 'x', 2: 'y'},
|
||||||
out_axes=({0: 'x', 2: 'y'}),
|
{0: 'x', 2: 'y'},),
|
||||||
axis_resources=axis_resources)
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
def cic_read(mesh, pos, halo_size):
|
axis_resources=axis_resources)
|
||||||
|
def fn(mesh, pos):
|
||||||
|
|
||||||
# Halo exchange to grab neighboring borders
|
# Halo exchange to grab neighboring borders
|
||||||
# Exchange along x
|
# Exchange along x
|
||||||
left = lax.pshuffle(mesh[-halo_size:],
|
left = lax.pshuffle(mesh[-halo_size:],
|
||||||
perm=range(mesh_size['nx'])[::-1],
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
axis_name='x')
|
axis_name='x')
|
||||||
right = lax.pshuffle(mesh[:halo_size],
|
right = lax.pshuffle(mesh[:halo_size],
|
||||||
perm=range(mesh_size['nx'])[::-1],
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
axis_name='x')
|
axis_name='x')
|
||||||
mesh = jnp.concatenate([left, mesh, right], axis=0)
|
mesh = jnp.concatenate([left, mesh, right], axis=0)
|
||||||
# Exchange along y
|
# Exchange along y
|
||||||
left = lax.pshuffle(mesh[:, -halo_size:],
|
left = lax.pshuffle(mesh[:, -halo_size:],
|
||||||
perm=range(mesh_size['ny'])[::-1],
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
axis_name='y')
|
axis_name='y')
|
||||||
right = lax.pshuffle(mesh[:, :halo_size],
|
right = lax.pshuffle(mesh[:, :halo_size],
|
||||||
perm=range(mesh_size['ny'])[::-1],
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
axis_name='y')
|
axis_name='y')
|
||||||
mesh = jnp.concatenate([left, mesh, right], axis=1)
|
mesh = jnp.concatenate([left, mesh, right], axis=1)
|
||||||
|
|
||||||
# Reading field at particles positions
|
# Reading field at particles positions
|
||||||
res = jpm.painting.cic_read(mesh, pos.reshape(-1, 3) +
|
res = paint.cic_read(mesh, pos.reshape(-1, 3) +
|
||||||
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||||
|
|
||||||
return res
|
return res.reshape(pos.shape[:-1])
|
||||||
|
|
||||||
|
return fn(mesh, pos)
|
||||||
|
|
||||||
|
|
||||||
@partial(pjit,
|
@partial(pjit,
|
||||||
|
|
|
@ -16,7 +16,7 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
||||||
"""
|
"""
|
||||||
if mesh_shape is None:
|
if mesh_shape is None:
|
||||||
mesh_shape = delta_k.shape
|
mesh_shape = delta_k.shape
|
||||||
kvec = [k.squeeze() for k in fftk(mesh_shape)]
|
kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)]
|
||||||
|
|
||||||
if delta_k is None:
|
if delta_k is None:
|
||||||
delta = dops.cic_paint(positions, mesh_shape, halo_size)
|
delta = dops.cic_paint(positions, mesh_shape, halo_size)
|
||||||
|
@ -38,7 +38,7 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sample normal field
|
# Sample normal field
|
||||||
field = dops.normal(seed, mesh_shape)
|
field = dops.normal(seed, shape=mesh_shape)
|
||||||
|
|
||||||
# Go to Fourier space
|
# Go to Fourier space
|
||||||
field = dops.fft3d(dops.reshape_split_to_dense(field))
|
field = dops.fft3d(dops.reshape_split_to_dense(field))
|
||||||
|
@ -64,8 +64,9 @@ def lpt(cosmo, initial_conditions, positions, a):
|
||||||
Computes first order LPT displacement
|
Computes first order LPT displacement
|
||||||
"""
|
"""
|
||||||
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
||||||
|
print(initial_force.shape)
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
dx = dops.scalar_multiply(initial_force * growth_factor(cosmo, a))
|
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
|
||||||
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
||||||
jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||||
return dx, p
|
return dx, p
|
||||||
|
|
Loading…
Add table
Reference in a new issue