mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
apply formatting
This commit is contained in:
parent
11f7e90066
commit
4342279817
5 changed files with 26 additions and 22 deletions
|
@ -40,7 +40,7 @@ def ifft3d(x):
|
|||
def get_halo_size(halo_size, sharding):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is None or gpu_mesh.empty:
|
||||
zero_ext = (0, 0, 0)
|
||||
zero_ext = (0, 0)
|
||||
zero_tuple = (0, 0)
|
||||
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
|
||||
else:
|
||||
|
|
|
@ -5,8 +5,8 @@ import jax.lax as lax
|
|||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
|
||||
slice_pad, slice_unpad, fft3d, ifft3d)
|
||||
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
|
||||
ifft3d, slice_pad, slice_unpad)
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
|
|
@ -7,21 +7,21 @@ jax.distributed.initialize()
|
|||
rank = jax.process_index()
|
||||
size = jax.process_count()
|
||||
|
||||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
import numpy as np
|
||||
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
|
||||
diffeqsolve)
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
all_gather = partial(process_allgather, tiled=True)
|
||||
|
||||
|
@ -84,7 +84,8 @@ def run_simulation(omega_c, sigma8):
|
|||
return initial_conditions, dx, res.ys, res.stats
|
||||
|
||||
|
||||
initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
|
||||
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
||||
0.25, 0.8)
|
||||
print(f"[{rank}] Simulation completed")
|
||||
print(f"[{rank}] Solver stats: {solver_stats}")
|
||||
|
||||
|
@ -95,10 +96,9 @@ ode_solutions = [all_gather(sol) for sol in ode_solutions]
|
|||
|
||||
if rank == 0:
|
||||
np.savez("multihost_pm.npz",
|
||||
initial_conditions=initial_conditions,
|
||||
lpt_displacements=lpt_displacements,
|
||||
ode_solutions=ode_solutions,
|
||||
solver_stats=solver_stats)
|
||||
initial_conditions=initial_conditions,
|
||||
lpt_displacements=lpt_displacements,
|
||||
ode_solutions=ode_solutions,
|
||||
solver_stats=solver_stats)
|
||||
|
||||
print(f"[{rank}] Simulation results saved")
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def plot_fields(fields_dict, sum_over=None):
|
||||
"""
|
||||
|
@ -23,8 +24,10 @@ def plot_fields(fields_dict, sum_over=None):
|
|||
slicing = tuple(slicing)
|
||||
|
||||
# Sum projection over the specified axis and plot
|
||||
axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1,
|
||||
cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
|
||||
axes[row, proj_axis].imshow(
|
||||
field[slicing].sum(axis=proj_axis) + 1,
|
||||
cmap='magma',
|
||||
extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
|
||||
axes[row, proj_axis].set_xlabel('Mpc/h')
|
||||
axes[row, proj_axis].set_ylabel('Mpc/h')
|
||||
axes[row, proj_axis].set_title(title)
|
||||
|
@ -32,7 +35,8 @@ def plot_fields(fields_dict, sum_over=None):
|
|||
# Plot each field across the three axes
|
||||
for i, (name, field) in enumerate(fields_dict.items()):
|
||||
for proj_axis in range(3):
|
||||
plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}")
|
||||
plot_subplots(proj_axis, field, i,
|
||||
f"{name} projection {proj_axis}")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
|
2
setup.py
2
setup.py
|
@ -7,5 +7,5 @@ setup(
|
|||
author='JaxPM developers',
|
||||
description='A dead simple FastPM implementation in JAX',
|
||||
packages=find_packages(),
|
||||
install_requires=['jax', 'jax_cosmo','jaxdecomp'],
|
||||
install_requires=['jax', 'jax_cosmo', 'jaxdecomp'],
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue