mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 09:01:11 +00:00
Fix sharding error (#37)
* Use cosmo as arg for the ODE function * Update examples * format * notebook update * fix tests * add correct annotations for weights in painting and warning for cic_paint in distributed pm * update test_against_fpm * update distributed tests and add jacfwd jacrev and vmap tests * format * add Caveats to notebook readme * final touches * update Growth.py to allow using FastPM solver * fix 2D painting when input is (X , Y , 2) shape * update cic read halo size and notebooks examples * Allow env variable control of caching in growth * Format * update test jax version * update notebooks/03-MultiGPU_PM_Halo.ipynb * update numpy install in wf * update tolerance :) * reorganize install in test workflow * update tests * add mpi4py * update tests.yml * update tests * update wf * format * make normal_field signature consistent with jax.random.normal * update by default normal_field dtype to match JAX * format * debug test workflow * format * debug test workflow * updating tests * fix accuracy * fixed tolerance * adding caching * Update conftest.py * Update tolerance and precision settings in distributed PM tests * revererting back changes to growth.py --------- Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
parent
cb2a7ab17f
commit
6693e5c725
17 changed files with 675 additions and 298 deletions
34
.github/workflows/formatting.yml
vendored
34
.github/workflows/formatting.yml
vendored
|
@ -7,15 +7,37 @@ on:
|
|||
branches: [ "main" ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
formatting:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
- name: Checkout Source
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-formatting-pip-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-formatting-pip-
|
||||
|
||||
- name: Cache pre-commit
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pre-commit-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip isort
|
||||
python -m pip install pre-commit
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pre-commit isort
|
||||
|
||||
- name: Run pre-commit
|
||||
run: python -m pre_commit run --all-files
|
||||
|
|
54
.github/workflows/tests.yml
vendored
54
.github/workflows/tests.yml
vendored
|
@ -10,37 +10,63 @@ on:
|
|||
|
||||
jobs:
|
||||
run_tests:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10" , "3.11" , "3.12"]
|
||||
python-version: ["3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- name: Checkout Source
|
||||
uses: actions/checkout@v2.3.1
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-test.txt', '**/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-${{ matrix.python-version }}-
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Cache system dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /var/cache/apt
|
||||
key: ${{ runner.os }}-apt-${{ hashFiles('.github/workflows/tests.yml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-apt-
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libopenmpi-dev
|
||||
python -m pip install --upgrade pip
|
||||
pip install jax==0.4.35
|
||||
pip install numpy setuptools cython wheel
|
||||
pip install git+https://github.com/MP-Gadget/pfft-python
|
||||
pip install git+https://github.com/MP-Gadget/pmesh
|
||||
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
|
||||
pip install -r requirements-test.txt
|
||||
pip install .
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
# Install JAX first as it's a key dependency
|
||||
pip install jax
|
||||
# Install build dependencies
|
||||
pip install setuptools cython mpi4py
|
||||
# Install test requirements with no-build-isolation for faster builds
|
||||
pip install -r requirements-test.txt --no-build-isolation
|
||||
# Install additional test dependencies
|
||||
pip install pytest diffrax
|
||||
# Install package in development mode
|
||||
pip install -e .
|
||||
echo "numpy version installed:"
|
||||
python -c "import numpy; print(numpy.__version__)"
|
||||
|
||||
- name: Run Single Device Tests
|
||||
run: |
|
||||
cd tests
|
||||
pytest -v -m "not distributed"
|
||||
|
||||
- name: Run Distributed tests
|
||||
run: |
|
||||
pytest -v -m distributed
|
||||
pytest -v tests/test_distributed_pm.py
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue