JaxPM/.github/workflows/tests.yml
Wassim KABALAN 6693e5c725
Fix sharding error (#37)
* Use cosmo as arg for the ODE function

* Update examples

* format

* notebook update

* fix tests

* add correct annotations for weights in painting and warning for cic_paint in distributed pm

* update test_against_fpm

* update distributed tests and add jacfwd jacrev and vmap tests

* format

* add Caveats to notebook readme

* final touches

* update Growth.py to allow using FastPM solver

* fix 2D painting when input has shape (X, Y, 2)

* update cic read halo size and notebook examples

* Allow env variable control of caching in growth

* Format

* update test jax version

* update notebooks/03-MultiGPU_PM_Halo.ipynb

* update numpy install in wf

* update tolerance :)

* reorganize install in test workflow

* update tests

* add mpi4py

* update tests.yml

* update tests

* update wf

* format

* make normal_field signature consistent with jax.random.normal

* update the default normal_field dtype to match JAX

* format

* debug test workflow

* format

* debug test workflow

* updating tests

* fix accuracy

* fixed tolerance

* adding caching

* Update conftest.py

* Update tolerance and precision settings in distributed PM tests

* reverting changes to growth.py

---------

Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
2025-06-28 23:07:31 +02:00

name: Tests
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  run_tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12"]
    steps:
      - name: Checkout Source
        uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Cache pip dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-pip-${{ matrix.python-version }}-
            ${{ runner.os }}-pip-
      - name: Cache system dependencies
        uses: actions/cache@v4
        with:
          path: /var/cache/apt
          key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
          restore-keys: |
            ${{ runner.os }}-apt-
      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y libopenmpi-dev
      - name: Install Python dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          # Install JAX first as it's a key dependency
          pip install jax
          # Install build dependencies
          pip install setuptools cython mpi4py
          # Install test requirements with no-build-isolation for faster builds
          pip install -r requirements-test.txt --no-build-isolation
          # Install additional test dependencies
          pip install pytest diffrax
          # Install package in development mode
          pip install -e .
          echo "numpy version installed:"
          python -c "import numpy; print(numpy.__version__)"
      - name: Run Single Device Tests
        run: |
          cd tests
          pytest -v -m "not distributed"
      - name: Run Distributed tests
        run: |
          pytest -v tests/test_distributed_pm.py
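The matrix and the pip `actions/cache` step above interact in a way that is easy to misread. Each Python version gets its own exact key, but the last restore key (`${{ runner.os }}-pip-`) lets a job fall back to another version's cache. The Python sketch below models that lookup under simplifying assumptions. `fake_hash_files` is a hypothetical stand-in for `hashFiles()` and not GitHub's exact algorithm. `RUNNER_OS` is hard-coded to `Linux`. Branch scoping and cache versions are ignored. All function names are illustrative.

```python
import hashlib
import itertools

RUNNER_OS = "Linux"  # value of ${{ runner.os }} on ubuntu-latest runners
MATRIX = {"python-version": ["3.10", "3.11", "3.12"]}


def expand_matrix(matrix):
    """One job per combination of matrix values (a Cartesian product)."""
    names = list(matrix)
    return [dict(zip(names, combo)) for combo in itertools.product(*matrix.values())]


def fake_hash_files(blobs):
    """Hypothetical stand-in for hashFiles(): SHA-256 over per-file digests.

    Not GitHub's exact algorithm; it only shows that the key changes whenever
    any hashed file (requirements-test.txt, pyproject.toml) changes.
    """
    outer = hashlib.sha256()
    for blob in blobs:
        outer.update(hashlib.sha256(blob).digest())
    return outer.hexdigest()


def pip_cache_keys(python_version, files_digest):
    """Mirror the `key` and `restore-keys` of the 'Cache pip dependencies' step."""
    key = f"{RUNNER_OS}-pip-{python_version}-{files_digest}"
    restore_keys = [f"{RUNNER_OS}-pip-{python_version}-", f"{RUNNER_OS}-pip-"]
    return key, restore_keys


def resolve(key, restore_keys, saved):
    """Simplified actions/cache lookup (ignores branch scope and cache version).

    `saved` maps existing cache keys to a creation counter (higher = newer).
    An exact hit on `key` wins; otherwise each restore key is tried in order
    as a prefix, and the newest matching entry is returned.
    """
    if key in saved:
        return key
    for prefix in restore_keys:
        hits = [k for k in saved if k.startswith(prefix)]
        if hits:
            return max(hits, key=saved.__getitem__)
    return None


if __name__ == "__main__":
    # Placeholder bytes standing in for the two hashed files.
    digest = fake_hash_files([b"<requirements-test.txt>", b"<pyproject.toml>"])
    # Caches left over from an older dependency set; 3.10's entry is the newest.
    saved = {"Linux-pip-3.11-oldhash": 1, "Linux-pip-3.10-oldhash": 2}
    for job in expand_matrix(MATRIX):
        version = job["python-version"]
        key, restore_keys = pip_cache_keys(version, digest)
        print(version, "->", resolve(key, restore_keys, saved))
    # 3.10 -> Linux-pip-3.10-oldhash
    # 3.11 -> Linux-pip-3.11-oldhash
    # 3.12 -> Linux-pip-3.10-oldhash
```

In this model the 3.12 job has no cache of its own, so it restores the newest `Linux-pip-` entry, which belongs to 3.10. pip can then reuse whichever cached downloads still apply. Without the second restore key, that job would start with an empty pip cache.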
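The single-device step selects tests with `-m "not distributed"`. It therefore skips the sharded tests only if they carry a `distributed` marker. The following is a minimal `conftest.py` sketch that assumes the marker is registered in the `pytest_configure` hook rather than in `pyproject.toml`; the project's actual `conftest.py` may do this differently. Registering the marker stops pytest from emitting `PytestUnknownMarkWarning`, and from failing under `--strict-markers`.

```python
# conftest.py (hypothetical sketch): registers the "distributed" marker used by
# `pytest -m "not distributed"` in the single-device CI step.

DISTRIBUTED_MARKER = "distributed"


def pytest_configure(config):
    # Standard pytest hook: declare the custom marker so pytest knows it.
    config.addinivalue_line(
        "markers",
        f"{DISTRIBUTED_MARKER}: test requires several devices (sharded PM)",
    )


# A distributed test module would then apply the marker module-wide,
# for example:
#
#     import pytest
#     pytestmark = pytest.mark.distributed
```

With the marker in place, the first step deselects the marked tests. The second step names `tests/test_distributed_pm.py` explicitly and passes no `-m` filter, so it runs them.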