From e1a8134b8e51ed9ba5b98848c32f0cf55d2bb9be Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Thu, 12 Jun 2025 14:46:27 +0200 Subject: [PATCH 1/5] update by default normal_field dtype to match JAX --- jaxpm/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 86cd816..2d361f6 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(seed , shape, sharding=None, dtype='float32'): +def normal_field(seed , shape, sharding=None, dtype=float): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty): From 67a80e1041bd88076ab1e4a950720a1cb4b59ba1 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Thu, 12 Jun 2025 14:51:24 +0200 Subject: [PATCH 2/5] format --- jaxpm/distributed.py | 2 +- jaxpm/pm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 2d361f6..24adc6c 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(seed , shape, sharding=None, dtype=float): +def normal_field(seed, shape, sharding=None, dtype=float): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index ae7db3f..95dce20 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): Generate initial conditions. """ # Initialize a random field with one slice on each gpu - field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding) + field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding) field = fft3d(field) kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 From 187cf5c4ba68a48b2ca2b41f75b63898ece80042 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Thu, 12 Jun 2025 15:19:36 +0200 Subject: [PATCH 3/5] debug test workflow --- requirements-test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-test.txt b/requirements-test.txt index aa9de82..c6b2e46 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,4 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python pmesh @ git+https://github.com/MP-Gadget/pmesh fastpm @ git+https://github.com/ASKabalan/fastpm-python +numpy \ No newline at end of file From d874790543dfe4ab9087ec8d0a0256aa9ce17b94 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Thu, 12 Jun 2025 15:20:20 +0200 Subject: [PATCH 4/5] format --- .github/workflows/tests.yml | 2 ++ requirements-test.txt | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 89b760b..041d897 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,8 @@ jobs: pip install pytest pip install diffrax pip install . + echo "numpy version installed:" + python -c "import numpy; print(numpy.__version__)" - name: Run Single Device Tests run: | diff --git a/requirements-test.txt b/requirements-test.txt index c6b2e46..e68f835 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,4 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python pmesh @ git+https://github.com/MP-Gadget/pmesh fastpm @ git+https://github.com/ASKabalan/fastpm-python -numpy \ No newline at end of file +numpy From 7623e60581f274a6ae71bbc3701bb51c0398d921 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Thu, 12 Jun 2025 15:28:03 +0200 Subject: [PATCH 5/5] debug test workflow --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index e68f835..70a7229 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,4 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python pmesh @ git+https://github.com/MP-Gadget/pmesh fastpm @ git+https://github.com/ASKabalan/fastpm-python -numpy +numpy==2.2.6