add precommit

This commit is contained in:
Wassim KABALAN 2024-07-08 00:21:33 +02:00
parent f49ed6fe70
commit 38e875a7df
2 changed files with 41 additions and 24 deletions

17
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)

View file

@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
## Objective
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
## Related Work
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
- Borg
@ -75,7 +75,7 @@ boxsize = [1024., 1024., 1024.]
nmesh = [1024, 1024, 1024]
# Create distributed frequencies
kvec = jpm.fftk(nmesh, sharding)
kvec = jpm.fftk(nmesh, sharding)
# Generate initial positions
particles = jpm.generate_initial_positions(nmesh, sharding)
@ -95,10 +95,10 @@ snapshots = jnp.linespace(0.1, 1, 10)
def high_level_api_simulation(cosmo, kvec, particles)
# Initial conditions is a distributed 3D array
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
# Create a particular solver
solver = jpm.solvers.fastpm(cosmo, B=1)
solver = jpm.solvers.fastpm(cosmo, B=1)
# Initialize and run the simulation
state = solver.init(initial_conditions)
state = solver.nbody(state) # Will use base leapfrog integrator
@ -106,12 +106,12 @@ def high_level_api_simulation(cosmo, kvec, particles)
diffrax_solver = diffrax.Dopri5()
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
state = solver.nbody(state , diffrax_solver, stepsize_controller, t0=0.1, t1=1 , dt=0.01 , s=snapshots)
# Painting the results
# User defined function to compute weights
weights = get_weights(particles)
density = jpm.cic_paint(state.positions , weights)
return density
with spmd_config:
@ -122,17 +122,17 @@ with spmd_config:
def mid_level_api_simulation(cosmo, kvec, particles)
# Initial conditions is a distributed 3D array
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
# Create a particular solver
solver = jpm.solvers.fastpm(cosmo, B=1)
solver = jpm.solvers.fastpm(cosmo, B=1)
# Initialize and run the simulation
state = solver.init(initial_conditions , kvec)
state = solver.lpt(state , a=0.1)
# OR
state = solver.lpt2(state , a=0.1) # Does both LPT1 and LPT2
# Run the nbody simulation
state = solver.Euler(state , t0=0.1, t1=1 , dt=0.01)
# OR
@ -145,7 +145,7 @@ def mid_level_api_simulation(cosmo, kvec, particles)
term = ODETerm(lambda t, state, args: ode_fn(state, t, args))
diffrax_solver = diffrax.Dopri5()
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
ode_solution = diffrax.diffeqsolve(term,
diffrax_solver,
t0=0.1,
@ -155,12 +155,12 @@ def mid_level_api_simulation(cosmo, kvec, particles)
saveat=SaveAt(t0=False,t1=True,ts=snapshots),
args=cosmo,
stepsize_controller=stepsize_controller)
final_state = ode_solution.ys[-1]
# User defined function to compute weights
weights = get_weights(particles)
density = jpm.cic_paint(final_state.positions , weights)
return state
with spmd_config:
@ -171,7 +171,7 @@ with spmd_config:
def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
# Initial conditions is a distributed 3D array
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
# Create a particular solver
initial_force = pm_forces(inital_conditions,nmesh)
# First order LPT
@ -189,26 +189,26 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
p += p2
state = jpm.empty_state(dx , p , kvec)
def nbody_ode(state, a, cosmo):
pos, vel , kvec= state.positions, state.velocities , state.kvec
forces = pm_forces(pos, mesh_shape=nmesh) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
term = ODETerm(lambda t, state, args: nbody_ode(state, t, args))
diffrax_solver = diffrax.Dopri5()
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
ode_solution = diffeqsolve(term,
diffrax_solver,
t0=0.1,
@ -223,7 +223,7 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
# User defined function to compute weights
weights = get_weights(particles)
density = jpm.cic_paint(final_state.positions , weights)
return state
with spmd_config:
@ -237,4 +237,4 @@ Users can also mix and match the different levels of API to create their own cus
### TODOs
- [ ] Implement tsc_paint and tsc_compenstate
- [ ] Implement a distributed power spectrum calculation
- [ ] Implement a distributed power spectrum calculation