Compare commits

..

5 commits

Author SHA1 Message Date
Wassim Kabalan
7623e60581 debug test workflow 2025-06-12 15:28:03 +02:00
Wassim Kabalan
d874790543 format 2025-06-12 15:20:20 +02:00
Wassim Kabalan
187cf5c4ba debug test workflow 2025-06-12 15:19:36 +02:00
Wassim Kabalan
67a80e1041 format 2025-06-12 14:51:24 +02:00
Wassim Kabalan
e1a8134b8e update by default normal_field dtype to match JAX 2025-06-12 14:46:27 +02:00
4 changed files with 5 additions and 2 deletions

View file

@ -35,6 +35,8 @@ jobs:
pip install pytest pip install pytest
pip install diffrax pip install diffrax
pip install . pip install .
echo "numpy version installed:"
python -c "import numpy; print(numpy.__version__)"
- name: Run Single Device Tests - name: Run Single Device Tests
run: | run: |

View file

@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1) axis=-1)
def normal_field(seed , shape, sharding=None, dtype='float32'): def normal_field(seed, shape, sharding=None, dtype=float):
"""Generate a Gaussian random field with the given power spectrum.""" """Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty): if gpu_mesh is not None and not (gpu_mesh.empty):

View file

@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
Generate initial conditions. Generate initial conditions.
""" """
# Initialize a random field with one slice on each gpu # Initialize a random field with one slice on each gpu
field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding) field = normal_field(seed=seed, shape=mesh_shape, sharding=sharding)
field = fft3d(field) field = fft3d(field)
kvec = fftk(field) kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 kmesh = sum((kk / box_size[i] * mesh_shape[i])**2

View file

@ -1,3 +1,4 @@
pfft-python @ git+https://github.com/MP-Gadget/pfft-python pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python fastpm @ git+https://github.com/ASKabalan/fastpm-python
numpy==2.2.6