From 7d7657370172f1496352938f84fe6db7a9b71b40 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sun, 29 Jun 2025 10:51:45 +0200 Subject: [PATCH] Replace np.power 0.5 by np.sqrt (#43) * Switch **2 by np.sqrt * format --------- Co-authored-by: Francois Lanusse --- jaxpm/growth.py | 2 +- jaxpm/pm.py | 2 +- jaxpm/utils.py | 2 +- tests/conftest.py | 5 +++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/jaxpm/growth.py b/jaxpm/growth.py index ec248f3..3ce5476 100644 --- a/jaxpm/growth.py +++ b/jaxpm/growth.py @@ -26,7 +26,7 @@ def E(cosmo, a): where :math:`f(a)` is the Dark Energy evolution parameter computed by :py:meth:`.f_de`. """ - return np.power(Esqr(cosmo, a), 0.5) + return np.sqrt(Esqr(cosmo, a)) def df_de(cosmo, a, epsilon=1e-5): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 95dce20..7a307a6 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -139,7 +139,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / ( box_size[0] * box_size[1] * box_size[2]) - field = field * (pkmesh)**0.5 + field = field * jnp.sqrt(pkmesh) field = ifft3d(field) return field diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 96faeea..01898c5 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -52,7 +52,7 @@ def _initialize_pk(mesh_shape, box_shape, kedges, los): kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1 kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape) for m, l, kshape in zip(mesh_shape, box_shape, kshapes)] - kmesh = sum(ki**2 for ki in kvec)**0.5 + kmesh = jnp.sqrt(sum(ki**2 for ki in kvec)) dig = np.digitize(kmesh.reshape(-1), kedges) kcount = np.bincount(dig, minlength=len(kedges) + 1) diff --git a/tests/conftest.py b/tests/conftest.py index 2a43119..5871dae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,8 +96,9 @@ def fpm_initial_conditions(cosmo, particle_mesh): whitec = particle_mesh.generate_whitenoise(42, type='complex', unitary=False) - lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5 - * v * (1 / v.BoxSize).prod()**0.5) + lineark = whitec.apply(lambda k, v: jnp.sqrt( + pk_fn(jnp.sqrt(sum(ki**2 for ki in k)))) * v * jnp.sqrt( + (1 / v.BoxSize).prod())) init_mesh = lineark.c2r().value # XXX return lineark, grid, init_mesh