diff --git a/jaxpm/growth.py b/jaxpm/growth.py index e4d3815..f1392b6 100644 --- a/jaxpm/growth.py +++ b/jaxpm/growth.py @@ -28,7 +28,7 @@ def E(cosmo, a): where :math:`f(a)` is the Dark Energy evolution parameter computed by :py:meth:`.f_de`. """ - return np.power(Esqr(cosmo, a), 0.5) + return np.sqrt(Esqr(cosmo, a)) def df_de(cosmo, a, epsilon=1e-5): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 95dce20..7a307a6 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -139,7 +139,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / ( box_size[0] * box_size[1] * box_size[2]) - field = field * (pkmesh)**0.5 + field = field * jnp.sqrt(pkmesh) field = ifft3d(field) return field diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 96faeea..01898c5 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -52,7 +52,7 @@ def _initialize_pk(mesh_shape, box_shape, kedges, los): kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1 kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape) for m, l, kshape in zip(mesh_shape, box_shape, kshapes)] - kmesh = sum(ki**2 for ki in kvec)**0.5 + kmesh = jnp.sqrt(sum(ki**2 for ki in kvec)) dig = np.digitize(kmesh.reshape(-1), kedges) kcount = np.bincount(dig, minlength=len(kedges) + 1) diff --git a/tests/conftest.py b/tests/conftest.py index 2a43119..5871dae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,8 +96,9 @@ def fpm_initial_conditions(cosmo, particle_mesh): whitec = particle_mesh.generate_whitenoise(42, type='complex', unitary=False) - lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5 - * v * (1 / v.BoxSize).prod()**0.5) + lineark = whitec.apply(lambda k, v: jnp.sqrt( + pk_fn(jnp.sqrt(sum(ki**2 for ki in k)))) * v * jnp.sqrt( + (1 / v.BoxSize).prod())) init_mesh = lineark.c2r().value # XXX return lineark, grid, init_mesh