From 67a80e1041bd88076ab1e4a950720a1cb4b59ba1 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Thu, 12 Jun 2025 14:51:24 +0200 Subject: [PATCH] format --- jaxpm/distributed.py | 2 +- jaxpm/pm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 2d361f6..24adc6c 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(seed , shape, sharding=None, dtype=float): +def normal_field(seed, shape, sharding=None, dtype=float): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index ae7db3f..95dce20 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): Generate initial conditions. """ # Initialize a random field with one slice on each gpu - field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding) + field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding) field = fft3d(field) kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2