diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 041d897..89b760b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,8 +35,6 @@ jobs: pip install pytest pip install diffrax pip install . - echo "numpy version installed:" - python -c "import numpy; print(numpy.__version__)" - name: Run Single Device Tests run: | diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 24adc6c..86cd816 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(seed, shape, sharding=None, dtype=float): +def normal_field(seed , shape, sharding=None, dtype='float32'): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 95dce20..ae7db3f 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): Generate initial conditions. """ # Initialize a random field with one slice on each gpu - field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding) + field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding) field = fft3d(field) kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 diff --git a/requirements-test.txt b/requirements-test.txt index 70a7229..aa9de82 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,3 @@ pfft-python @ git+https://github.com/MP-Gadget/pfft-python pmesh @ git+https://github.com/MP-Gadget/pmesh fastpm @ git+https://github.com/ASKabalan/fastpm-python -numpy==2.2.6