diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 89b760b..041d897 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,8 @@ jobs: pip install pytest pip install diffrax pip install . + echo "numpy version installed:" + python -c "import numpy; print(numpy.__version__)" - name: Run Single Device Tests run: | diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 86cd816..24adc6c 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(seed , shape, sharding=None, dtype='float32'): +def normal_field(seed, shape, sharding=None, dtype=float): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index ae7db3f..95dce20 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): Generate initial conditions. """ # Initialize a random field with one slice on each gpu - field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding) + field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding) field = fft3d(field) kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 diff --git a/requirements-test.txt b/requirements-test.txt index aa9de82..70a7229 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,4 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python pmesh @ git+https://github.com/MP-Gadget/pmesh fastpm @ git+https://github.com/ASKabalan/fastpm-python +numpy==2.2.6