mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
apply formating
This commit is contained in:
parent
c93894f561
commit
19011d0712
5 changed files with 22 additions and 15 deletions
|
@ -117,9 +117,11 @@ def get_local_shape(mesh_shape, sharding):
|
||||||
else:
|
else:
|
||||||
pdims = gpu_mesh.devices.shape
|
pdims = gpu_mesh.devices.shape
|
||||||
return [
|
return [
|
||||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:]
|
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
|
||||||
|
*mesh_shape[2:]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def zeros(mesh_shape, sharding):
|
def zeros(mesh_shape, sharding):
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||||
|
@ -132,6 +134,7 @@ def zeros(mesh_shape , sharding):
|
||||||
else:
|
else:
|
||||||
return jnp.zeros(mesh_shape)
|
return jnp.zeros(mesh_shape)
|
||||||
|
|
||||||
|
|
||||||
def normal_field(mesh_shape, seed, sharding):
|
def normal_field(mesh_shape, seed, sharding):
|
||||||
"""Generate a Gaussian random field with the given power spectrum."""
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
|
|
@ -588,4 +588,5 @@ def dGf2a(cosmo, a):
|
||||||
f2p = cache['h2'] / cache['a'] * cache['g2']
|
f2p = cache['h2'] / cache['a'] * cache['g2']
|
||||||
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
||||||
E_a = E(cosmo, a)
|
E_a = E(cosmo, a)
|
||||||
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E_a * D2f)
|
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
|
||||||
|
3 * a**2 * E_a * D2f)
|
||||||
|
|
|
@ -29,8 +29,10 @@ def pm_forces(positions,
|
||||||
mesh_shape = delta.shape
|
mesh_shape = delta.shape
|
||||||
|
|
||||||
if paint_particles:
|
if paint_particles:
|
||||||
paint_fn = lambda x: cic_paint(
|
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
|
||||||
zeros(mesh_shape,sharding), x , halo_size=halo_size, sharding=sharding)
|
x,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
read_fn = lambda x: cic_read(
|
read_fn = lambda x: cic_read(
|
||||||
x, positions, halo_size=halo_size, sharding=sharding)
|
x, positions, halo_size=halo_size, sharding=sharding)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -35,6 +35,7 @@ box_size = [500., 500., 1000.]
|
||||||
halo_size = 64
|
halo_size = 64
|
||||||
snapshots = jnp.linspace(0.1, 1., 2)
|
snapshots = jnp.linspace(0.1, 1., 2)
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def run_simulation(omega_c, sigma8):
|
def run_simulation(omega_c, sigma8):
|
||||||
# Create a small function to generate the matter power spectrum
|
# Create a small function to generate the matter power spectrum
|
||||||
|
@ -89,9 +90,11 @@ print(f"[{rank}] Solver stats: {solver_stats}")
|
||||||
|
|
||||||
# Gather the results
|
# Gather the results
|
||||||
|
|
||||||
pm_dict = {"initial_conditions": all_gather(initial_conditions),
|
pm_dict = {
|
||||||
|
"initial_conditions": all_gather(initial_conditions),
|
||||||
"lpt_displacements": all_gather(lpt_displacements),
|
"lpt_displacements": all_gather(lpt_displacements),
|
||||||
"solver_stats": solver_stats}
|
"solver_stats": solver_stats
|
||||||
|
}
|
||||||
|
|
||||||
for i in range(len(ode_solutions)):
|
for i in range(len(ode_solutions)):
|
||||||
sol = ode_solutions[i]
|
sol = ode_solutions[i]
|
||||||
|
|
|
@ -62,11 +62,9 @@ def plot_fields_single_projection(fields_dict, sum_over=None):
|
||||||
slicing = tuple(slicing)
|
slicing = tuple(slicing)
|
||||||
|
|
||||||
# Sum projection over axis 0 and plot
|
# Sum projection over axis 0 and plot
|
||||||
axes[i].imshow(
|
axes[i].imshow(field[slicing].sum(axis=0) + 1,
|
||||||
field[slicing].sum(axis=0) + 1,
|
|
||||||
cmap='magma',
|
cmap='magma',
|
||||||
extent=[0, field.shape[1], 0, field.shape[2]]
|
extent=[0, field.shape[1], 0, field.shape[2]])
|
||||||
)
|
|
||||||
axes[i].set_xlabel('Mpc/h')
|
axes[i].set_xlabel('Mpc/h')
|
||||||
axes[i].set_ylabel('Mpc/h')
|
axes[i].set_ylabel('Mpc/h')
|
||||||
axes[i].set_title(f"{name} projection 0")
|
axes[i].set_title(f"{name} projection 0")
|
||||||
|
|
Loading…
Add table
Reference in a new issue