apply formatting

This commit is contained in:
Wassim KABALAN 2024-10-27 00:52:14 +02:00
parent 11f7e90066
commit 4342279817
5 changed files with 26 additions and 22 deletions

View file

@ -40,7 +40,7 @@ def ifft3d(x):
def get_halo_size(halo_size, sharding): def get_halo_size(halo_size, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is None or gpu_mesh.empty: if gpu_mesh is None or gpu_mesh.empty:
zero_ext = (0, 0, 0) zero_ext = (0, 0)
zero_tuple = (0, 0) zero_tuple = (0, 0)
return (zero_tuple, zero_tuple, zero_tuple), zero_ext return (zero_tuple, zero_tuple, zero_tuple), zero_ext
else: else:

View file

@ -5,8 +5,8 @@ import jax.lax as lax
import jax.numpy as jnp import jax.numpy as jnp
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange, from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
slice_pad, slice_unpad, fft3d, ifft3d) ifft3d, slice_pad, slice_unpad)
from jaxpm.kernels import cic_compensation, fftk from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter from jaxpm.painting_utils import gather, scatter

View file

@ -7,21 +7,21 @@ jax.distributed.initialize()
rank = jax.process_index() rank = jax.process_index()
size = jax.process_count() size = jax.process_count()
from functools import partial
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
import numpy as np
from jaxpm.kernels import interpolate_power_spectrum from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
from jaxpm.painting import cic_paint_dx diffeqsolve)
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jax.experimental.mesh_utils import create_device_mesh from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from functools import partial
import numpy as np
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
all_gather = partial(process_allgather, tiled=True) all_gather = partial(process_allgather, tiled=True)
@ -84,7 +84,8 @@ def run_simulation(omega_c, sigma8):
return initial_conditions, dx, res.ys, res.stats return initial_conditions, dx, res.ys, res.stats
initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8) initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
0.25, 0.8)
print(f"[{rank}] Simulation completed") print(f"[{rank}] Simulation completed")
print(f"[{rank}] Solver stats: {solver_stats}") print(f"[{rank}] Solver stats: {solver_stats}")
@ -101,4 +102,3 @@ if rank == 0:
solver_stats=solver_stats) solver_stats=solver_stats)
print(f"[{rank}] Simulation results saved") print(f"[{rank}] Simulation results saved")

View file

@ -1,6 +1,7 @@
import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
def plot_fields(fields_dict, sum_over=None): def plot_fields(fields_dict, sum_over=None):
""" """
@ -23,8 +24,10 @@ def plot_fields(fields_dict, sum_over=None):
slicing = tuple(slicing) slicing = tuple(slicing)
# Sum projection over the specified axis and plot # Sum projection over the specified axis and plot
axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1, axes[row, proj_axis].imshow(
cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]]) field[slicing].sum(axis=proj_axis) + 1,
cmap='magma',
extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
axes[row, proj_axis].set_xlabel('Mpc/h') axes[row, proj_axis].set_xlabel('Mpc/h')
axes[row, proj_axis].set_ylabel('Mpc/h') axes[row, proj_axis].set_ylabel('Mpc/h')
axes[row, proj_axis].set_title(title) axes[row, proj_axis].set_title(title)
@ -32,7 +35,8 @@ def plot_fields(fields_dict, sum_over=None):
# Plot each field across the three axes # Plot each field across the three axes
for i, (name, field) in enumerate(fields_dict.items()): for i, (name, field) in enumerate(fields_dict.items()):
for proj_axis in range(3): for proj_axis in range(3):
plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}") plot_subplots(proj_axis, field, i,
f"{name} projection {proj_axis}")
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()

View file

@ -7,5 +7,5 @@ setup(
author='JaxPM developers', author='JaxPM developers',
description='A dead simple FastPM implementation in JAX', description='A dead simple FastPM implementation in JAX',
packages=find_packages(), packages=find_packages(),
install_requires=['jax', 'jax_cosmo','jaxdecomp'], install_requires=['jax', 'jax_cosmo', 'jaxdecomp'],
) )