# CI workflow: installs the package and runs its pytest suite on every push
# and pull request targeting `main`, once per Python version in the matrix.
name: Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  run_tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Versions are quoted so that "3.10" is not parsed as the float 3.1.
        python-version: ["3.10", "3.11", "3.12"]

    steps:
      - name: Checkout Source
        uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      # The key combines OS, Python version and a hash of the dependency
      # manifests; the restore-keys fall back to progressively broader prefixes.
      - name: Cache pip dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-pip-${{ matrix.python-version }}-
            ${{ runner.os }}-pip-

      # The key is derived from the hash of .github/workflows/tests.yml.
      # NOTE(review): /var/cache/apt is normally root-owned, so this cache may
      # fail to save/restore for the unprivileged runner user -- confirm it is
      # actually producing cache hits.
      - name: Cache system dependencies
        uses: actions/cache@v4
        with:
          path: /var/cache/apt
          key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
          restore-keys: |
            ${{ runner.os }}-apt-

      # Installs the OpenMPI development package ahead of the mpi4py install
      # performed in the next step.
      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y libopenmpi-dev

      - name: Install Python dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          # Install JAX first as it's a key dependency
          pip install jax
          # Install build dependencies
          pip install setuptools cython mpi4py
          # Install test requirements with no-build-isolation for faster builds
          pip install -r requirements-test.txt --no-build-isolation
          # Install additional test dependencies
          pip install pytest diffrax
          # Install package in development mode
          pip install -e .
          echo "numpy version installed:"
          python -c "import numpy; print(numpy.__version__)"

      # Runs from inside tests/ and deselects tests marked `distributed`.
      - name: Run Single Device Tests
        run: |
          cd tests
          pytest -v -m "not distributed"

      # Runs the distributed test module on its own, from the repository root.
      - name: Run Distributed tests
        run: |
          pytest -v tests/test_distributed_pm.py