mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
add precommit
This commit is contained in:
parent
f49ed6fe70
commit
38e875a7df
2 changed files with 41 additions and 24 deletions
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,17 @@
|
|||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v2.3.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.40.2
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: ['--parallel', '--in-place']
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
name: isort (python)
|
48
design.md
48
design.md
|
@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
|
|||
|
||||
## Objective
|
||||
|
||||
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
||||
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
||||
|
||||
## Related Work
|
||||
|
||||
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
|
||||
|
||||
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
|
||||
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
||||
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
||||
- Borg
|
||||
|
||||
|
||||
|
@ -75,7 +75,7 @@ boxsize = [1024., 1024., 1024.]
|
|||
nmesh = [1024, 1024, 1024]
|
||||
|
||||
# Create distributed frequencies
|
||||
kvec = jpm.fftk(nmesh, sharding)
|
||||
kvec = jpm.fftk(nmesh, sharding)
|
||||
# Generate initial positions
|
||||
particles = jpm.generate_initial_positions(nmesh, sharding)
|
||||
|
||||
|
@ -95,10 +95,10 @@ snapshots = jnp.linespace(0.1, 1, 10)
|
|||
def high_level_api_simulation(cosmo, kvec, particles)
|
||||
# Initial conditions is a distributed 3D array
|
||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||
|
||||
|
||||
# Create a particular solver
|
||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||
|
||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||
|
||||
# Initialize and run the simulation
|
||||
state = solver.init(initial_conditions)
|
||||
state = solver.nbody(state) # Will use base leapfrog integrator
|
||||
|
@ -106,12 +106,12 @@ def high_level_api_simulation(cosmo, kvec, particles)
|
|||
diffrax_solver = diffrax.Dopri5()
|
||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||
state = solver.nbody(state , diffrax_solver, stepsize_controller, t0=0.1, t1=1 , dt=0.01 , s=snapshots)
|
||||
|
||||
|
||||
# Painting the results
|
||||
# User defined function to compute weights
|
||||
weights = get_weights(particles)
|
||||
density = jpm.cic_paint(state.positions , weights)
|
||||
|
||||
|
||||
return density
|
||||
|
||||
with spmd_config:
|
||||
|
@ -122,17 +122,17 @@ with spmd_config:
|
|||
def mid_level_api_simulation(cosmo, kvec, particles)
|
||||
# Initial conditions is a distributed 3D array
|
||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||
|
||||
|
||||
# Create a particular solver
|
||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||
|
||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||
|
||||
# Initialize and run the simulation
|
||||
state = solver.init(initial_conditions , kvec)
|
||||
|
||||
state = solver.lpt(state , a=0.1)
|
||||
# OR
|
||||
state = solver.lpt2(state , a=0.1) # Does both LPT1 and LPT2
|
||||
|
||||
|
||||
# Run the nbody simulation
|
||||
state = solver.Euler(state , t0=0.1, t1=1 , dt=0.01)
|
||||
# OR
|
||||
|
@ -145,7 +145,7 @@ def mid_level_api_simulation(cosmo, kvec, particles)
|
|||
term = ODETerm(lambda t, state, args: ode_fn(state, t, args))
|
||||
diffrax_solver = diffrax.Dopri5()
|
||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||
|
||||
|
||||
ode_solution = diffrax.diffeqsolve(term,
|
||||
diffrax_solver,
|
||||
t0=0.1,
|
||||
|
@ -155,12 +155,12 @@ def mid_level_api_simulation(cosmo, kvec, particles)
|
|||
saveat=SaveAt(t0=False,t1=True,ts=snapshots),
|
||||
args=cosmo,
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
|
||||
final_state = ode_solution.ys[-1]
|
||||
# User defined function to compute weights
|
||||
weights = get_weights(particles)
|
||||
density = jpm.cic_paint(final_state.positions , weights)
|
||||
|
||||
|
||||
return state
|
||||
|
||||
with spmd_config:
|
||||
|
@ -171,7 +171,7 @@ with spmd_config:
|
|||
def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
||||
# Initial conditions is a distributed 3D array
|
||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||
|
||||
|
||||
# Create a particular solver
|
||||
initial_force = pm_forces(inital_conditions,nmesh)
|
||||
# First order LPT
|
||||
|
@ -189,26 +189,26 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
|||
p += p2
|
||||
|
||||
state = jpm.empty_state(dx , p , kvec)
|
||||
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
|
||||
pos, vel , kvec= state.positions, state.velocities , state.kvec
|
||||
|
||||
|
||||
forces = pm_forces(pos, mesh_shape=nmesh) * 1.5 * cosmo.Omega_m
|
||||
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
|
||||
return dpos, dvel
|
||||
|
||||
term = ODETerm(lambda t, state, args: nbody_ode(state, t, args))
|
||||
diffrax_solver = diffrax.Dopri5()
|
||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||
|
||||
|
||||
|
||||
ode_solution = diffeqsolve(term,
|
||||
diffrax_solver,
|
||||
t0=0.1,
|
||||
|
@ -223,7 +223,7 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
|||
# User defined function to compute weights
|
||||
weights = get_weights(particles)
|
||||
density = jpm.cic_paint(final_state.positions , weights)
|
||||
|
||||
|
||||
return state
|
||||
|
||||
with spmd_config:
|
||||
|
@ -237,4 +237,4 @@ Users can also mix and match the different levels of API to create their own cus
|
|||
### TODOs
|
||||
|
||||
- [ ] Implement tsc_paint and tsc_compenstate
|
||||
- [ ] Implement a distributed power spectrum calculation
|
||||
- [ ] Implement a distributed power spectrum calculation
|
||||
|
|
Loading…
Add table
Reference in a new issue