update data

This commit is contained in:
EiffL 2022-10-20 21:03:15 -07:00
parent 137f4e5099
commit ebab27b4c3
2 changed files with 91 additions and 82 deletions

View file

@ -6,7 +6,7 @@ from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit, PartitionSpec
import jax_cosmo as jc
import jaxpm as jpm
import jaxpm.painting as paint
# TODO: add a way to configure axis resources from command line
axis_resources = {'x': 'nx', 'y': 'ny'}
@ -24,7 +24,7 @@ def stack3d(a, b, c):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'}),
in_axes=({0: 'x', 2: 'y'},[...]),
out_axes=({0: 'x', 2: 'y'}),
axis_resources=axis_resources)
def scalar_multiply(a, factor):
@ -41,8 +41,8 @@ def add(a, b):
@partial(xmap,
in_axes=['x', 'y'],
out_axes=['x', 'y'],
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
axis_resources=axis_resources)
def fft3d(mesh):
""" Performs a 3D complex Fourier transform
@ -62,8 +62,8 @@ def fft3d(mesh):
@partial(xmap,
in_axes=['x', 'y'],
out_axes=['x', 'y'],
in_axes=['x', 'y',...],
out_axes=['x', 'y',...],
axis_resources=axis_resources)
def ifft3d(mesh):
mesh = jnp.fft.ifft(mesh)
@ -72,12 +72,12 @@ def ifft3d(mesh):
mesh = lax.all_to_all(mesh, 'x', 0, 0)
return jnp.fft.ifft(mesh).real
def normal(key, shape=[]):
@partial(xmap,
in_axes=['x', 'y'],
in_axes=['x', 'y',...],
out_axes={0: 'x', 2: 'y'},
axis_resources=axis_resources)
def normal(key, shape):
def fn(key):
""" Generate a distributed random normal distributions
Args:
key: array of random keys with same layout as computational mesh
@ -85,11 +85,12 @@ def normal(key, shape):
"""
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
shape[1]//mesh_size['ny']]+shape[2:])
return fn(key)
@partial(xmap,
in_axes=(['x', 'y', ...],
[['x'], ['y'], ...]),
[['x'], ['y'], [...]], [...], [...]),
out_axes=['x', 'y', ...],
axis_resources=axis_resources)
@jax.jit
@ -124,21 +125,24 @@ def meshgrid(x, y, z):
""" Generates a mesh grid of appropriate size for the
computational mesh we have.
"""
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
return jnp.stack(jnp.meshgrid(x,
y,
z), axis=-1)
def cic_paint(pos, mesh_shape, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'}),
out_axes=({0: 'x', 2: 'y'}),
axis_resources=axis_resources)
def cic_paint(pos, mesh_shape, halo_size=0):
def fn(pos):
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
mesh_shape[1]//mesh_size['ny']+2*halo_size]
+ mesh_shape[2:])
# Paint particles
mesh = jpm.cic_paint(mesh, pos.reshape(-1, 3) +
mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
# Perform halo exchange
@ -165,13 +169,15 @@ def cic_paint(pos, mesh_shape, halo_size=0):
# removing halo and returning mesh
return mesh[halo_size:-halo_size, halo_size:-halo_size]
return fn(pos)
def cic_read(mesh, pos, halo_size=0):
@partial(xmap,
in_axes=({0: 'x', 2: 'y'},
{0: 'x', 2: 'y'}),
{0: 'x', 2: 'y'},),
out_axes=({0: 'x', 2: 'y'}),
axis_resources=axis_resources)
def cic_read(mesh, pos, halo_size):
def fn(mesh, pos):
# Halo exchange to grab neighboring borders
# Exchange along x
@ -192,10 +198,12 @@ def cic_read(mesh, pos, halo_size):
mesh = jnp.concatenate([left, mesh, right], axis=1)
# Reading field at particles positions
res = jpm.painting.cic_read(mesh, pos.reshape(-1, 3) +
res = paint.cic_read(mesh, pos.reshape(-1, 3) +
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
return res
return res.reshape(pos.shape[:-1])
return fn(mesh, pos)
@partial(pjit,

View file

@ -16,7 +16,7 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
"""
if mesh_shape is None:
mesh_shape = delta_k.shape
kvec = [k.squeeze() for k in fftk(mesh_shape)]
kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)]
if delta_k is None:
delta = dops.cic_paint(positions, mesh_shape, halo_size)
@ -38,7 +38,7 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
"""
# Sample normal field
field = dops.normal(seed, mesh_shape)
field = dops.normal(seed, shape=mesh_shape)
# Go to Fourier space
field = dops.fft3d(dops.reshape_split_to_dense(field))
@ -64,8 +64,9 @@ def lpt(cosmo, initial_conditions, positions, a):
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta_k=initial_conditions)
print(initial_force.shape)
a = jnp.atleast_1d(a)
dx = dops.scalar_multiply(initial_force * growth_factor(cosmo, a))
dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
jnp.sqrt(jc.background.Esqr(cosmo, a)))
return dx, p