mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
format
This commit is contained in:
parent
20fe25c324
commit
1f5c619531
10 changed files with 290 additions and 210 deletions
|
@ -79,6 +79,7 @@ def slice_unpad_impl(x, pad_width):
|
||||||
|
|
||||||
return x[tuple(unpad_slice)]
|
return x[tuple(unpad_slice)]
|
||||||
|
|
||||||
|
|
||||||
def slice_pad_impl(x, pad_width):
|
def slice_pad_impl(x, pad_width):
|
||||||
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)
|
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.lax import FftType
|
from jax.lax import FftType
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from jaxdecomp import fftfreq3d, get_output_specs
|
from jaxdecomp import fftfreq3d, get_output_specs
|
||||||
import jax
|
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +26,8 @@ def fftk(k_array):
|
||||||
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
|
|
||||||
def pk_fn(input):
|
def pk_fn(input):
|
||||||
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
|
return jax.tree.map(
|
||||||
|
lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
|
||||||
|
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
specs = sharding.spec if sharding is not None else P()
|
specs = sharding.spec if sharding is not None else P()
|
||||||
|
@ -61,7 +63,8 @@ def gradient_kernel(kvec, direction, order=1):
|
||||||
return wts
|
return wts
|
||||||
else:
|
else:
|
||||||
w = kvec[direction]
|
w = kvec[direction]
|
||||||
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w)
|
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)),
|
||||||
|
w)
|
||||||
wts = a * 1j
|
wts = a * 1j
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
@ -85,11 +88,14 @@ def invlaplace_kernel(kvec, fd=False):
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
if fd:
|
if fd:
|
||||||
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec)
|
kk = sum(
|
||||||
|
jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki)
|
||||||
|
for ki in kvec)
|
||||||
else:
|
else:
|
||||||
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
|
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
|
||||||
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
|
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
|
||||||
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk)
|
return jax.tree.map(lambda x, y: -jnp.where(y == 0, 0, 1 / x), kk_nozeros,
|
||||||
|
kk)
|
||||||
|
|
||||||
|
|
||||||
def longrange_kernel(kvec, r_split):
|
def longrange_kernel(kvec, r_split):
|
||||||
|
@ -131,7 +137,10 @@ def cic_compensation(kvec):
|
||||||
wts: array
|
wts: array
|
||||||
Complex kernel values
|
Complex kernel values
|
||||||
"""
|
"""
|
||||||
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)]
|
kwts = [
|
||||||
|
jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i])
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
|
@ -19,28 +19,36 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
positions = positions.reshape([-1, 3])
|
positions = positions.reshape([-1, 3])
|
||||||
positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions)
|
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
|
||||||
floor = jax.tree.map(jnp.floor , positions)
|
floor = jax.tree.map(jnp.floor, positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords))
|
kernel = 1. - jax.tree.map(jnp.abs, (positions - neighboor_coords))
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
|
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
|
||||||
kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1)
|
kernel = jax.tree.map(
|
||||||
, k) , kernel , weight)
|
lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
|
||||||
|
kernel, weight)
|
||||||
else:
|
else:
|
||||||
kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight)
|
kernel = jax.tree.map(
|
||||||
|
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
|
||||||
|
kernel, weight)
|
||||||
|
|
||||||
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords)
|
neighboor_coords = jax.tree.map(
|
||||||
|
lambda nc: jnp.mod(
|
||||||
|
nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)
|
||||||
|
), neighboor_coords)
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||||
inserted_window_dims=(0, 1, 2),
|
inserted_window_dims=(0, 1, 2),
|
||||||
scatter_dims_to_operand_dims=(0, 1,
|
scatter_dims_to_operand_dims=(0, 1,
|
||||||
2))
|
2))
|
||||||
mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel)
|
mesh = jax.tree.map(
|
||||||
|
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums),
|
||||||
|
grid_mesh, neighboor_coords, kernel)
|
||||||
|
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
@ -49,7 +57,8 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||||
|
|
||||||
positions_structure = jax.tree.structure(positions)
|
positions_structure = jax.tree.structure(positions)
|
||||||
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh))
|
grid_mesh = jax.tree.unflatten(positions_structure,
|
||||||
|
jax.tree.leaves(grid_mesh))
|
||||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||||
|
|
||||||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||||
|
@ -79,24 +88,27 @@ def _cic_read_impl(grid_mesh, positions):
|
||||||
# Reshape positions to a flat list of 3D coordinates
|
# Reshape positions to a flat list of 3D coordinates
|
||||||
positions = positions.reshape([-1, 3])
|
positions = positions.reshape([-1, 3])
|
||||||
# Expand dimensions to calculate neighbor coordinates
|
# Expand dimensions to calculate neighbor coordinates
|
||||||
positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions)
|
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
|
||||||
# Floor the positions to get the base grid cell for each particle
|
# Floor the positions to get the base grid cell for each particle
|
||||||
floor = jax.tree.map(jnp.floor , positions)
|
floor = jax.tree.map(jnp.floor, positions)
|
||||||
# Define connections to calculate all neighbor coordinates
|
# Define connections to calculate all neighbor coordinates
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
# Calculate the 8 neighboring coordinates
|
# Calculate the 8 neighboring coordinates
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
# Calculate kernel weights based on distance from each neighboring coordinate
|
# Calculate kernel weights based on distance from each neighboring coordinate
|
||||||
kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords)
|
kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
# Modulo operation to wrap around edges if necessary
|
# Modulo operation to wrap around edges if necessary
|
||||||
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32')
|
neighboor_coords = jax.tree.map(
|
||||||
,jnp.array(grid_mesh.shape)) , neighboor_coords)
|
lambda nc: jnp.mod(nc.astype('int32'), jnp.array(grid_mesh.shape)),
|
||||||
|
neighboor_coords)
|
||||||
|
|
||||||
# Ensure grid_mesh shape is as expected
|
# Ensure grid_mesh shape is as expected
|
||||||
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
||||||
grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel)
|
grid_mesh = jax.tree.map(
|
||||||
|
lambda g, nc, k: g[nc[..., 0], nc[..., 1], nc[..., 2]] * k, grid_mesh,
|
||||||
|
neighboor_coords, kernel)
|
||||||
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,7 +169,9 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||||
halo_y, _ = halo_size[1]
|
halo_y, _ = halo_size[1]
|
||||||
|
|
||||||
original_shape = displacements.shape
|
original_shape = displacements.shape
|
||||||
particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements)
|
particle_mesh = jax.tree.map(
|
||||||
|
lambda x: jnp.zeros(x.shape[:-1], dtype=displacements.dtype),
|
||||||
|
displacements)
|
||||||
if not jnp.isscalar(weight):
|
if not jnp.isscalar(weight):
|
||||||
if weight.shape != original_shape[:-1]:
|
if weight.shape != original_shape[:-1]:
|
||||||
raise ValueError("Weight shape must match particle shape")
|
raise ValueError("Weight shape must match particle shape")
|
||||||
|
@ -165,13 +179,18 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||||
weight = weight.flatten()
|
weight = weight.flatten()
|
||||||
# Padding is forced to be zero in a single gpu run
|
# Padding is forced to be zero in a single gpu run
|
||||||
|
|
||||||
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
|
a, b, c = jax.tree.map(
|
||||||
|
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
|
||||||
jnp.arange(x.shape[1]),
|
jnp.arange(x.shape[1]),
|
||||||
jnp.arange(x.shape[2]),
|
jnp.arange(x.shape[2]),
|
||||||
indexing='ij') , axis=0), particle_mesh)
|
indexing='ij'),
|
||||||
|
axis=0), particle_mesh)
|
||||||
|
|
||||||
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh)
|
particle_mesh = jax.tree.map(lambda x: jnp.pad(x, halo_size),
|
||||||
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
particle_mesh)
|
||||||
|
pmid = jax.tree.map(
|
||||||
|
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
|
||||||
|
c)
|
||||||
return scatter(pmid.reshape([-1, 3]),
|
return scatter(pmid.reshape([-1, 3]),
|
||||||
displacements.reshape([-1, 3]),
|
displacements.reshape([-1, 3]),
|
||||||
particle_mesh,
|
particle_mesh,
|
||||||
|
@ -217,12 +236,16 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||||
jnp.arange(original_shape[1]),
|
jnp.arange(original_shape[1]),
|
||||||
jnp.arange(original_shape[2]),
|
jnp.arange(original_shape[2]),
|
||||||
indexing='ij')
|
indexing='ij')
|
||||||
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
|
a, b, c = jax.tree.map(
|
||||||
|
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||||
jnp.arange(original_shape[1]),
|
jnp.arange(original_shape[1]),
|
||||||
jnp.arange(original_shape[2]),
|
jnp.arange(original_shape[2]),
|
||||||
indexing='ij') , axis=0), grid_mesh)
|
indexing='ij'),
|
||||||
|
axis=0), grid_mesh)
|
||||||
|
|
||||||
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
pmid = jax.tree.map(
|
||||||
|
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
|
||||||
|
c)
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
disp = disp.reshape([-1, 3])
|
disp = disp.reshape([-1, 3])
|
||||||
|
|
||||||
|
|
|
@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||||
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
new_cell_size, new_shape):
|
new_cell_size, new_shape):
|
||||||
"""Multilinear enmeshing."""
|
"""Multilinear enmeshing."""
|
||||||
base_indices = jax.tree.map(jnp.asarray , base_indices)
|
base_indices = jax.tree.map(jnp.asarray, base_indices)
|
||||||
displacements = jax.tree.map(jnp.asarray , displacements)
|
displacements = jax.tree.map(jnp.asarray, displacements)
|
||||||
with jax.experimental.enable_x64():
|
with jax.experimental.enable_x64():
|
||||||
cell_size = jnp.float64(
|
cell_size = jnp.float64(
|
||||||
cell_size) if new_cell_size is not None else jnp.array(
|
cell_size) if new_cell_size is not None else jnp.array(
|
||||||
|
@ -61,8 +61,8 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
new_displacements = particle_positions - new_indices * new_cell_size
|
new_displacements = particle_positions - new_indices * new_cell_size
|
||||||
|
|
||||||
if base_shape is not None:
|
if base_shape is not None:
|
||||||
new_displacements -= jax.tree.map(jnp.rint ,
|
new_displacements -= jax.tree.map(
|
||||||
new_displacements / grid_length
|
jnp.rint, new_displacements / grid_length
|
||||||
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||||
|
|
||||||
new_indices = new_indices.astype(base_indices.dtype)
|
new_indices = new_indices.astype(base_indices.dtype)
|
||||||
|
@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
if base_shape is not None:
|
if base_shape is not None:
|
||||||
new_indices %= base_shape
|
new_indices %= base_shape
|
||||||
|
|
||||||
weights = 1 - jax.tree.map(jnp.abs , new_displacements)
|
weights = 1 - jax.tree.map(jnp.abs, new_displacements)
|
||||||
|
|
||||||
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||||
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||||
|
@ -109,11 +109,15 @@ def _scatter_chunk(carry, chunk):
|
||||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||||
spatial_shape)
|
spatial_shape)
|
||||||
# scatter
|
# scatter
|
||||||
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
|
||||||
|
ind)
|
||||||
mesh_structure = jax.tree.structure(mesh)
|
mesh_structure = jax.tree.structure(mesh)
|
||||||
val_flat = jax.tree.leaves(val)
|
val_flat = jax.tree.leaves(val)
|
||||||
val_tree = jax.tree.unflatten(mesh_structure, val_flat)
|
val_tree = jax.tree.unflatten(mesh_structure, val_flat)
|
||||||
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac)
|
mesh = jax.tree.map(
|
||||||
|
lambda m, v, i, f: m.at[i].add(
|
||||||
|
jnp.multiply(jnp.expand_dims(v, axis=-1), f)), mesh, val_tree, ind,
|
||||||
|
frac)
|
||||||
carry = mesh, offset, cell_size, mesh_shape
|
carry = mesh, offset, cell_size, mesh_shape
|
||||||
return carry, None
|
return carry, None
|
||||||
|
|
||||||
|
@ -127,8 +131,8 @@ def scatter(pmid,
|
||||||
cell_size=1.):
|
cell_size=1.):
|
||||||
|
|
||||||
ptcl_num, spatial_ndim = pmid.shape
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
val = jax.tree.map(jnp.asarray , val)
|
val = jax.tree.map(jnp.asarray, val)
|
||||||
mesh = jax.tree.map(jnp.asarray , mesh)
|
mesh = jax.tree.map(jnp.asarray, mesh)
|
||||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||||
carry = mesh, offset, cell_size, mesh.shape
|
carry = mesh, offset, cell_size, mesh.shape
|
||||||
if remainder is not None:
|
if remainder is not None:
|
||||||
|
@ -151,9 +155,9 @@ def _chunk_cat(remainder_array, chunked_array):
|
||||||
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||||
ptcl_num, spatial_ndim = pmid.shape
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
|
||||||
mesh = jax.tree.map(jnp.asarray , mesh)
|
mesh = jax.tree.map(jnp.asarray, mesh)
|
||||||
|
|
||||||
val = jax.tree.map(jnp.asarray , val)
|
val = jax.tree.map(jnp.asarray, val)
|
||||||
|
|
||||||
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||||
raise ValueError('channel shape mismatch: '
|
raise ValueError('channel shape mismatch: '
|
||||||
|
@ -187,11 +191,15 @@ def _gather_chunk(carry, chunk):
|
||||||
spatial_shape)
|
spatial_shape)
|
||||||
|
|
||||||
# gather
|
# gather
|
||||||
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
|
||||||
|
ind)
|
||||||
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
|
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
|
||||||
ind_structure = jax.tree.structure(ind)
|
ind_structure = jax.tree.structure(ind)
|
||||||
frac_structure = jax.tree.structure(frac)
|
frac_structure = jax.tree.structure(frac)
|
||||||
mesh_structure = jax.tree.structure(mesh)
|
mesh_structure = jax.tree.structure(mesh)
|
||||||
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac)
|
val += jax.tree.map(
|
||||||
|
lambda m, i, f:
|
||||||
|
(m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1), mesh, ind,
|
||||||
|
frac)
|
||||||
|
|
||||||
return carry, val
|
return carry, val
|
||||||
|
|
19
jaxpm/pm.py
19
jaxpm/pm.py
|
@ -1,3 +1,4 @@
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
|
||||||
|
@ -7,7 +8,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||||
invlaplace_kernel, longrange_kernel)
|
invlaplace_kernel, longrange_kernel)
|
||||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||||
import jax
|
|
||||||
|
|
||||||
def pm_forces(positions,
|
def pm_forces(positions,
|
||||||
mesh_shape=None,
|
mesh_shape=None,
|
||||||
|
@ -52,10 +53,12 @@ def pm_forces(positions,
|
||||||
kvec, r_split=r_split)
|
kvec, r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = [
|
forces = [
|
||||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions)
|
||||||
) for i in range(3)]
|
for i in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2])
|
forces = jax.tree.map(lambda x, y, z: jnp.stack([x, y, z], axis=-1),
|
||||||
|
forces[0], forces[1], forces[2])
|
||||||
|
|
||||||
return forces
|
return forces
|
||||||
|
|
||||||
|
@ -73,8 +76,9 @@ def lpt(cosmo,
|
||||||
"""
|
"""
|
||||||
paint_absolute_pos = particles is not None
|
paint_absolute_pos = particles is not None
|
||||||
if particles is None:
|
if particles is None:
|
||||||
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
|
particles = jax.tree.map(
|
||||||
shape=(*ic.shape, 3)) , initial_conditions)
|
lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
|
||||||
|
initial_conditions)
|
||||||
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||||
|
@ -198,7 +202,8 @@ def make_diffrax_ode(mesh_shape,
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel)
|
return jax.tree.map(lambda dp, dv: jnp.stack([dp, dv], axis=0), dpos,
|
||||||
|
dvel)
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.scipy.stats import norm
|
from jax.scipy.stats import norm
|
||||||
from scipy.special import legendre
|
from scipy.special import legendre
|
||||||
import jax
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||||
'cross_correlation_coefficients', 'gaussian_smoothing'
|
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -100,11 +101,12 @@ def power_spectrum(mesh,
|
||||||
n_bins = len(kavg) + 2
|
n_bins = len(kavg) + 2
|
||||||
|
|
||||||
# FFTs
|
# FFTs
|
||||||
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh)
|
meshk = jax.tree.map(lambda x: jnp.fft.fftn(x, norm='ortho'), mesh)
|
||||||
if mesh2 is None:
|
if mesh2 is None:
|
||||||
mmk = meshk.real**2 + meshk.imag**2
|
mmk = meshk.real**2 + meshk.imag**2
|
||||||
else:
|
else:
|
||||||
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2)
|
mmk = meshk * jax.tree.map(
|
||||||
|
lambda x: jnp.fft.fftn(x, norm='ortho').conj(), mesh2)
|
||||||
|
|
||||||
# Sum powers
|
# Sum powers
|
||||||
pk = jnp.empty((len(poles), n_bins))
|
pk = jnp.empty((len(poles), n_bins))
|
||||||
|
|
|
@ -174,18 +174,22 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
||||||
|
|
||||||
return fpm_mesh
|
return fpm_mesh
|
||||||
|
|
||||||
|
|
||||||
def compare_sharding(sharding1, sharding2):
|
def compare_sharding(sharding1, sharding2):
|
||||||
|
|
||||||
def get_axis_size(sharding, idx):
|
def get_axis_size(sharding, idx):
|
||||||
axis_name = sharding.spec[idx]
|
axis_name = sharding.spec[idx]
|
||||||
if axis_name is None:
|
if axis_name is None:
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return sharding.mesh.shape[sharding.spec[idx]]
|
return sharding.mesh.shape[sharding.spec[idx]]
|
||||||
|
|
||||||
def get_pdims_from_sharding(sharding):
|
def get_pdims_from_sharding(sharding):
|
||||||
return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))])
|
return tuple(
|
||||||
|
[get_axis_size(sharding, i) for i in range(len(sharding.spec))])
|
||||||
|
|
||||||
pdims1 = get_pdims_from_sharding(sharding1)
|
pdims1 = get_pdims_from_sharding(sharding1)
|
||||||
pdims2 = get_pdims_from_sharding(sharding2)
|
pdims2 = get_pdims_from_sharding(sharding2)
|
||||||
pdims1 = pdims1 + (1,) * (3 - len(pdims1))
|
pdims1 = pdims1 + (1, ) * (3 - len(pdims1))
|
||||||
pdims2 = pdims2 + (1,) * (3 - len(pdims2))
|
pdims2 = pdims2 + (1, ) * (3 - len(pdims2))
|
||||||
return pdims1 == pdims2
|
return pdims1 == pdims2
|
|
@ -1,14 +1,15 @@
|
||||||
|
import jax
|
||||||
import pytest
|
import pytest
|
||||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||||
from helpers import MSE, MSRE
|
from helpers import MSE, MSRE
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
|
||||||
from jaxdecomp import ShardedArray
|
from jaxdecomp import ShardedArray
|
||||||
|
|
||||||
from jaxpm.distributed import uniform_particles
|
from jaxpm.distributed import uniform_particles
|
||||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||||
from jaxpm.pm import lpt, make_diffrax_ode
|
from jaxpm.pm import lpt, make_diffrax_ode
|
||||||
from jaxpm.utils import power_spectrum
|
from jaxpm.utils import power_spectrum
|
||||||
import jax
|
|
||||||
_TOLERANCE = 1e-4
|
_TOLERANCE = 1e-4
|
||||||
_PM_TOLERANCE = 1e-3
|
_PM_TOLERANCE = 1e-3
|
||||||
|
|
||||||
|
@ -17,7 +18,8 @@ _PM_TOLERANCE = 1e-3
|
||||||
@pytest.mark.parametrize("order", [1, 2])
|
@pytest.mark.parametrize("order", [1, 2])
|
||||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||||
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
|
fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
|
||||||
|
shardedArrayAPI):
|
||||||
|
|
||||||
mesh_shape, box_shape = simulation_config
|
mesh_shape, box_shape = simulation_config
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
@ -53,7 +55,8 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
@pytest.mark.parametrize("order", [1, 2])
|
@pytest.mark.parametrize("order", [1, 2])
|
||||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||||
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
|
fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
|
||||||
|
shardedArrayAPI):
|
||||||
|
|
||||||
mesh_shape, box_shape = simulation_config
|
mesh_shape, box_shape = simulation_config
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
@ -77,12 +80,13 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||||
assert type(dx) == ShardedArray
|
assert type(dx) == ShardedArray
|
||||||
assert type(lpt_field) == ShardedArray
|
assert type(lpt_field) == ShardedArray
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.single_device
|
@pytest.mark.single_device
|
||||||
@pytest.mark.parametrize("order", [1, 2])
|
@pytest.mark.parametrize("order", [1, 2])
|
||||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||||
def test_nbody_absolute(simulation_config, initial_conditions,
|
def test_nbody_absolute(simulation_config, initial_conditions,
|
||||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||||
cosmo, order , shardedArrayAPI):
|
cosmo, order, shardedArrayAPI):
|
||||||
|
|
||||||
mesh_shape, box_shape = simulation_config
|
mesh_shape, box_shape = simulation_config
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
@ -110,7 +114,8 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
||||||
|
|
||||||
saveat = SaveAt(t1=True)
|
saveat = SaveAt(t1=True)
|
||||||
|
|
||||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p)
|
y0 = jax.tree.map(lambda particles, dx, p: jnp.stack([particles + dx, p]),
|
||||||
|
particles, dx, p)
|
||||||
|
|
||||||
solutions = diffeqsolve(ode_fn,
|
solutions = diffeqsolve(ode_fn,
|
||||||
solver,
|
solver,
|
||||||
|
@ -135,7 +140,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
||||||
|
|
||||||
if shardedArrayAPI:
|
if shardedArrayAPI:
|
||||||
assert type(dx) == ShardedArray
|
assert type(dx) == ShardedArray
|
||||||
assert type( solutions.ys[-1, 0]) == ShardedArray
|
assert type(solutions.ys[-1, 0]) == ShardedArray
|
||||||
assert type(final_field) == ShardedArray
|
assert type(final_field) == ShardedArray
|
||||||
|
|
||||||
|
|
||||||
|
@ -144,7 +149,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
||||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||||
def test_nbody_relative(simulation_config, initial_conditions,
|
def test_nbody_relative(simulation_config, initial_conditions,
|
||||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||||
cosmo, order , shardedArrayAPI):
|
cosmo, order, shardedArrayAPI):
|
||||||
|
|
||||||
mesh_shape, box_shape = simulation_config
|
mesh_shape, box_shape = simulation_config
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
@ -155,8 +160,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
||||||
# Initial displacement
|
# Initial displacement
|
||||||
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||||
|
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
|
||||||
|
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
controller = PIDController(rtol=1e-9,
|
controller = PIDController(rtol=1e-9,
|
||||||
|
@ -167,7 +171,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
||||||
|
|
||||||
saveat = SaveAt(t1=True)
|
saveat = SaveAt(t1=True)
|
||||||
|
|
||||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p)
|
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||||
|
|
||||||
solutions = diffeqsolve(ode_fn,
|
solutions = diffeqsolve(ode_fn,
|
||||||
solver,
|
solver,
|
||||||
|
@ -192,5 +196,5 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
||||||
|
|
||||||
if shardedArrayAPI:
|
if shardedArrayAPI:
|
||||||
assert type(dx) == ShardedArray
|
assert type(dx) == ShardedArray
|
||||||
assert type( solutions.ys[-1, 0]) == ShardedArray
|
assert type(solutions.ys[-1, 0]) == ShardedArray
|
||||||
assert type(final_field) == ShardedArray
|
assert type(final_field) == ShardedArray
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
from conftest import initialize_distributed , compare_sharding
|
from conftest import compare_sharding, initialize_distributed
|
||||||
|
|
||||||
initialize_distributed() # ignore : E402
|
initialize_distributed() # ignore : E402
|
||||||
|
|
||||||
|
from functools import partial # noqa : E402
|
||||||
|
|
||||||
import jax # noqa : E402
|
import jax # noqa : E402
|
||||||
import jax.numpy as jnp # noqa : E402
|
import jax.numpy as jnp # noqa : E402
|
||||||
|
import jax_cosmo as jc # noqa : E402
|
||||||
import pytest # noqa : E402
|
import pytest # noqa : E402
|
||||||
from diffrax import SaveAt # noqa : E402
|
from diffrax import SaveAt # noqa : E402
|
||||||
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
|
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
|
||||||
|
@ -12,13 +15,13 @@ from jax import lax # noqa : E402
|
||||||
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
||||||
from jax.sharding import NamedSharding
|
from jax.sharding import NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P # noqa : E402
|
from jax.sharding import PartitionSpec as P # noqa : E402
|
||||||
from jaxpm.pm import pm_forces # noqa : E402
|
|
||||||
from jaxpm.distributed import uniform_particles , fft3d # noqa : E402
|
|
||||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
|
||||||
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
|
||||||
from jaxdecomp import ShardedArray # noqa : E402
|
from jaxdecomp import ShardedArray # noqa : E402
|
||||||
from functools import partial # noqa : E402
|
|
||||||
import jax_cosmo as jc # noqa : E402
|
from jaxpm.distributed import fft3d, uniform_particles # noqa : E402
|
||||||
|
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||||
|
from jaxpm.pm import pm_forces # noqa : E402
|
||||||
|
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
||||||
|
|
||||||
_TOLERANCE = 3.0 # 🙃🙃
|
_TOLERANCE = 3.0 # 🙃🙃
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,7 +30,7 @@ _TOLERANCE = 3.0 # 🙃🙃
|
||||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||||
absolute_painting,shardedArrayAPI):
|
absolute_painting, shardedArrayAPI):
|
||||||
|
|
||||||
mesh_shape, box_shape = simulation_config
|
mesh_shape, box_shape = simulation_config
|
||||||
# SINGLE DEVICE RUN
|
# SINGLE DEVICE RUN
|
||||||
|
@ -42,18 +45,16 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||||
if shardedArrayAPI:
|
if shardedArrayAPI:
|
||||||
particles = ShardedArray(particles)
|
particles = ShardedArray(particles)
|
||||||
# Initial displacement
|
# Initial displacement
|
||||||
dx, p, _ = lpt(cosmo,
|
dx, p, _ = lpt(cosmo, ic, particles, a=0.1, order=order)
|
||||||
ic,
|
|
||||||
particles,
|
|
||||||
a=0.1,
|
|
||||||
order=order)
|
|
||||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
|
y0 = jax.tree.map(
|
||||||
|
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
|
||||||
|
dx, p)
|
||||||
else:
|
else:
|
||||||
dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
|
dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
|
||||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
paint_absolute_pos=False))
|
||||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
|
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||||
|
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
controller = PIDController(rtol=1e-8,
|
controller = PIDController(rtol=1e-8,
|
||||||
|
@ -87,13 +88,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
halo_size = mesh_shape[0] // 2
|
halo_size = mesh_shape[0] // 2
|
||||||
|
|
||||||
ic = lax.with_sharding_constraint(initial_conditions,
|
ic = lax.with_sharding_constraint(initial_conditions, sharding)
|
||||||
sharding)
|
|
||||||
|
|
||||||
print(f"sharded initial conditions {ic.sharding}")
|
print(f"sharded initial conditions {ic.sharding}")
|
||||||
|
|
||||||
if shardedArrayAPI:
|
if shardedArrayAPI:
|
||||||
ic = ShardedArray(ic , sharding)
|
ic = ShardedArray(ic, sharding)
|
||||||
|
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
if absolute_painting:
|
if absolute_painting:
|
||||||
|
@ -110,12 +110,13 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(
|
make_diffrax_ode(mesh_shape,
|
||||||
mesh_shape,
|
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding))
|
sharding=sharding))
|
||||||
|
|
||||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
|
y0 = jax.tree.map(
|
||||||
|
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
|
||||||
|
dx, p)
|
||||||
else:
|
else:
|
||||||
dx, p, _ = lpt(cosmo,
|
dx, p, _ = lpt(cosmo,
|
||||||
ic,
|
ic,
|
||||||
|
@ -124,12 +125,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(
|
make_diffrax_ode(mesh_shape,
|
||||||
mesh_shape,
|
|
||||||
paint_absolute_pos=False,
|
paint_absolute_pos=False,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding))
|
sharding=sharding))
|
||||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
|
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||||
|
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
controller = PIDController(rtol=1e-8,
|
controller = PIDController(rtol=1e-8,
|
||||||
|
@ -170,17 +170,18 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||||
|
|
||||||
if shardedArrayAPI:
|
if shardedArrayAPI:
|
||||||
assert type(multi_device_final_field) == ShardedArray
|
assert type(multi_device_final_field) == ShardedArray
|
||||||
assert compare_sharding(multi_device_final_field.sharding , sharding)
|
assert compare_sharding(multi_device_final_field.sharding, sharding)
|
||||||
assert compare_sharding(multi_device_final_field.initial_sharding , sharding)
|
assert compare_sharding(multi_device_final_field.initial_sharding,
|
||||||
|
sharding)
|
||||||
|
|
||||||
assert mse < _TOLERANCE
|
assert mse < _TOLERANCE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.distributed
|
@pytest.mark.distributed
|
||||||
@pytest.mark.parametrize("order", [1, 2])
|
@pytest.mark.parametrize("order", [1, 2])
|
||||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||||
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2,
|
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
||||||
|
order, nbody_from_lpt1, nbody_from_lpt2,
|
||||||
absolute_painting):
|
absolute_painting):
|
||||||
|
|
||||||
mesh_shape, box_shape = simulation_config
|
mesh_shape, box_shape = simulation_config
|
||||||
|
@ -196,14 +197,12 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
|
||||||
|
|
||||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||||
|
|
||||||
|
initial_conditions = ShardedArray(initial_conditions, sharding)
|
||||||
initial_conditions = ShardedArray(initial_conditions , sharding)
|
|
||||||
|
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def forward_model(initial_conditions , cosmo):
|
def forward_model(initial_conditions, cosmo):
|
||||||
|
|
||||||
|
|
||||||
if absolute_painting:
|
if absolute_painting:
|
||||||
particles = uniform_particles(mesh_shape, sharding=sharding)
|
particles = uniform_particles(mesh_shape, sharding=sharding)
|
||||||
|
@ -218,12 +217,13 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(
|
make_diffrax_ode(mesh_shape,
|
||||||
mesh_shape,
|
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding))
|
sharding=sharding))
|
||||||
|
|
||||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
|
y0 = jax.tree.map(
|
||||||
|
lambda particles, dx, p: jnp.stack([particles + dx, p]),
|
||||||
|
particles, dx, p)
|
||||||
else:
|
else:
|
||||||
dx, p, _ = lpt(cosmo,
|
dx, p, _ = lpt(cosmo,
|
||||||
initial_conditions,
|
initial_conditions,
|
||||||
|
@ -232,12 +232,11 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(
|
make_diffrax_ode(mesh_shape,
|
||||||
mesh_shape,
|
|
||||||
paint_absolute_pos=False,
|
paint_absolute_pos=False,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding))
|
sharding=sharding))
|
||||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
|
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||||
|
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
controller = PIDController(rtol=1e-8,
|
controller = PIDController(rtol=1e-8,
|
||||||
|
@ -271,30 +270,31 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
|
||||||
return multi_device_final_field
|
return multi_device_final_field
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def model(initial_conditions , cosmo):
|
def model(initial_conditions, cosmo):
|
||||||
|
|
||||||
final_field = forward_model(initial_conditions , cosmo)
|
final_field = forward_model(initial_conditions, cosmo)
|
||||||
final_field, = jax.tree.leaves(final_field)
|
final_field, = jax.tree.leaves(final_field)
|
||||||
|
|
||||||
return MSE(final_field,
|
return MSE(final_field,
|
||||||
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
|
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
|
||||||
|
|
||||||
obs_val = model(initial_conditions , cosmo)
|
obs_val = model(initial_conditions, cosmo)
|
||||||
|
|
||||||
shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5
|
shifted_initial_conditions = initial_conditions + jax.random.normal(
|
||||||
|
jax.random.key(42), initial_conditions.shape) * 5
|
||||||
|
|
||||||
good_grads = jax.grad(model)(initial_conditions , cosmo)
|
good_grads = jax.grad(model)(initial_conditions, cosmo)
|
||||||
off_grads = jax.grad(model)(shifted_initial_conditions , cosmo)
|
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
|
||||||
|
|
||||||
assert compare_sharding(good_grads.sharding , initial_conditions.sharding)
|
assert compare_sharding(good_grads.sharding, initial_conditions.sharding)
|
||||||
assert compare_sharding(off_grads.sharding , initial_conditions.sharding)
|
assert compare_sharding(off_grads.sharding, initial_conditions.sharding)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.distributed
|
@pytest.mark.distributed
|
||||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||||
def test_fwd_rev_gradients(cosmo,absolute_painting):
|
def test_fwd_rev_gradients(cosmo, absolute_painting):
|
||||||
|
|
||||||
mesh_shape, box_shape = (8 , 8 , 8) , (20.0 , 20.0 , 20.0)
|
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||||
# SINGLE DEVICE RUN
|
# SINGLE DEVICE RUN
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
|
||||||
|
@ -308,17 +308,23 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
|
||||||
sharding)
|
sharding)
|
||||||
|
|
||||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||||
initial_conditions = ShardedArray(initial_conditions , sharding)
|
initial_conditions = ShardedArray(initial_conditions, sharding)
|
||||||
|
|
||||||
cosmo._workspace = {}
|
cosmo._workspace = {}
|
||||||
|
|
||||||
@partial(jax.jit , static_argnums=(3,4 , 5))
|
@partial(jax.jit, static_argnums=(3, 4, 5))
|
||||||
def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None):
|
def compute_forces(initial_conditions,
|
||||||
|
cosmo,
|
||||||
|
particles=None,
|
||||||
|
a=0.5,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
paint_absolute_pos = particles is not None
|
paint_absolute_pos = particles is not None
|
||||||
if particles is None:
|
if particles is None:
|
||||||
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
|
particles = jax.tree.map(
|
||||||
shape=(*ic.shape, 3)) , initial_conditions)
|
lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
|
||||||
|
initial_conditions)
|
||||||
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||||
|
@ -329,13 +335,27 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
return initial_force[...,0]
|
return initial_force[..., 0]
|
||||||
|
|
||||||
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None
|
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding),
|
||||||
forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
|
sharding) if absolute_painting else None
|
||||||
back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
|
forces = compute_forces(initial_conditions,
|
||||||
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
|
cosmo,
|
||||||
|
particles=particles,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
|
||||||
|
cosmo,
|
||||||
|
particles=particles,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
|
||||||
|
cosmo,
|
||||||
|
particles=particles,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
|
||||||
assert compare_sharding(forces.sharding , initial_conditions.sharding)
|
assert compare_sharding(forces.sharding, initial_conditions.sharding)
|
||||||
assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding)
|
assert compare_sharding(back_gradient[0, 0, 0, ...].sharding,
|
||||||
assert compare_sharding(fwd_gradient.sharding , initial_conditions.sharding)
|
initial_conditions.sharding)
|
||||||
|
assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding)
|
||||||
|
|
|
@ -1,31 +1,31 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||||
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["EQX_ON_ERROR"] = "nan"
|
os.environ["EQX_ON_ERROR"] = "nan"
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
|
||||||
|
diffeqsolve)
|
||||||
from jax.debug import visualize_array_sharding
|
from jax.debug import visualize_array_sharding
|
||||||
|
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
|
||||||
from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read
|
|
||||||
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
|
||||||
from functools import partial
|
|
||||||
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
|
|
||||||
from jaxpm.distributed import uniform_particles
|
|
||||||
|
|
||||||
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from jax.experimental.mesh_utils import create_device_mesh
|
from jax.experimental.mesh_utils import create_device_mesh
|
||||||
from jax.experimental.multihost_utils import process_allgather
|
from jax.experimental.multihost_utils import process_allgather
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
from jaxpm.distributed import uniform_particles
|
||||||
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
|
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||||
|
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||||
|
|
||||||
|
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
|
||||||
|
|
||||||
|
|
||||||
all_gather = partial(process_allgather, tiled=False)
|
all_gather = partial(process_allgather, tiled=False)
|
||||||
|
|
||||||
pdims = (2, 4)
|
pdims = (2, 4)
|
||||||
|
@ -34,8 +34,8 @@ pdims = (2, 4)
|
||||||
#sharding = NamedSharding(mesh, P('x', 'y'))
|
#sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
sharding = None
|
sharding = None
|
||||||
|
|
||||||
|
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
from jaxdecomp import ShardedArray
|
from jaxdecomp import ShardedArray
|
||||||
|
|
||||||
mesh_shape = 64
|
mesh_shape = 64
|
||||||
|
@ -43,19 +43,21 @@ box_size = 64.
|
||||||
halo_size = 2
|
halo_size = 2
|
||||||
snapshots = (0.5, 1.0)
|
snapshots = (0.5, 1.0)
|
||||||
|
|
||||||
|
|
||||||
class Params(NamedTuple):
|
class Params(NamedTuple):
|
||||||
omega_c: float
|
omega_c: float
|
||||||
sigma8: float
|
sigma8: float
|
||||||
initial_conditions : jnp.ndarray
|
initial_conditions: jnp.ndarray
|
||||||
|
|
||||||
mesh_shape = (mesh_shape,) * 3
|
|
||||||
box_size = (box_size,) * 3
|
mesh_shape = (mesh_shape, ) * 3
|
||||||
|
box_size = (box_size, ) * 3
|
||||||
omega_c = 0.25
|
omega_c = 0.25
|
||||||
sigma8 = 0.8
|
sigma8 = 0.8
|
||||||
# Create a small function to generate the matter power spectrum
|
# Create a small function to generate the matter power spectrum
|
||||||
k = jnp.logspace(-4, 1, 128)
|
k = jnp.logspace(-4, 1, 128)
|
||||||
pk = jc.power.linear_matter_power(
|
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8),
|
||||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
k)
|
||||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||||
|
|
||||||
initial_conditions = linear_field(mesh_shape,
|
initial_conditions = linear_field(mesh_shape,
|
||||||
|
@ -64,21 +66,19 @@ initial_conditions = linear_field(mesh_shape,
|
||||||
seed=jax.random.PRNGKey(0),
|
seed=jax.random.PRNGKey(0),
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
|
|
||||||
#initial_conditions = ShardedArray(initial_conditions, sharding)
|
#initial_conditions = ShardedArray(initial_conditions, sharding)
|
||||||
|
|
||||||
params = Params(omega_c, sigma8, initial_conditions)
|
params = Params(omega_c, sigma8, initial_conditions)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
|
||||||
@partial(jax.jit , static_argnums=(1 , 2,3,4 ))
|
def forward_model(params, mesh_shape, box_size, halo_size, snapshots):
|
||||||
def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
|
|
||||||
|
|
||||||
# Create initial conditions
|
# Create initial conditions
|
||||||
cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8)
|
cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8)
|
||||||
particles = uniform_particles(mesh_shape , sharding)
|
particles = uniform_particles(mesh_shape, sharding)
|
||||||
ic_structure = jax.tree.structure(params.initial_conditions)
|
ic_structure = jax.tree.structure(params.initial_conditions)
|
||||||
particles = jax.tree.unflatten(ic_structure , jax.tree.leaves(particles))
|
particles = jax.tree.unflatten(ic_structure, jax.tree.leaves(particles))
|
||||||
# Initial displacement
|
# Initial displacement
|
||||||
dx, p, f = lpt(cosmo,
|
dx, p, f = lpt(cosmo,
|
||||||
params.initial_conditions,
|
params.initial_conditions,
|
||||||
|
@ -90,10 +90,15 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
|
||||||
|
|
||||||
# Evolve the simulation forward
|
# Evolve the simulation forward
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding))
|
make_diffrax_ode(mesh_shape,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding))
|
||||||
solver = LeapfrogMidpoint()
|
solver = LeapfrogMidpoint()
|
||||||
|
|
||||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p)
|
y0 = jax.tree.map(
|
||||||
|
lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0),
|
||||||
|
particles, dx, p)
|
||||||
print(f"y0 structure: {jax.tree.structure(y0)}")
|
print(f"y0 structure: {jax.tree.structure(y0)}")
|
||||||
|
|
||||||
stepsize_controller = ConstantStepSize()
|
stepsize_controller = ConstantStepSize()
|
||||||
|
@ -108,17 +113,16 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
|
||||||
stepsize_controller=stepsize_controller)
|
stepsize_controller=stepsize_controller)
|
||||||
ode_solutions = [sol[0] for sol in res.ys]
|
ode_solutions = [sol[0] for sol in res.ys]
|
||||||
|
|
||||||
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1])
|
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32),
|
||||||
return particles + dx , ode_field
|
ode_solutions[-1])
|
||||||
|
return particles + dx, ode_field
|
||||||
|
|
||||||
ode_field = cic_paint_dx(ode_solutions[-1])
|
ode_field = cic_paint_dx(ode_solutions[-1])
|
||||||
return dx , ode_field
|
return dx, ode_field
|
||||||
|
|
||||||
|
|
||||||
|
lpt_particles, ode_field = forward_model(params, mesh_shape, box_size,
|
||||||
lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots)
|
halo_size, snapshots)
|
||||||
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
@ -127,11 +131,11 @@ lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles)
|
||||||
|
|
||||||
plt.figure(figsize=(12, 6))
|
plt.figure(figsize=(12, 6))
|
||||||
plt.subplot(121)
|
plt.subplot(121)
|
||||||
plt.imshow(lpt_field.sum(axis=0) , cmap='magma')
|
plt.imshow(lpt_field.sum(axis=0), cmap='magma')
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title('LPT field')
|
plt.title('LPT field')
|
||||||
plt.subplot(122)
|
plt.subplot(122)
|
||||||
plt.imshow(ode_field.sum(axis=0) , cmap='magma')
|
plt.imshow(ode_field.sum(axis=0), cmap='magma')
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.title('ODE field')
|
plt.title('ODE field')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
Loading…
Add table
Reference in a new issue