mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
add precommit
This commit is contained in:
parent
f49ed6fe70
commit
38e875a7df
2 changed files with 41 additions and 24 deletions
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v2.3.0
|
||||||
|
hooks:
|
||||||
|
- id: check-yaml
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- repo: https://github.com/google/yapf
|
||||||
|
rev: v0.40.2
|
||||||
|
hooks:
|
||||||
|
- id: yapf
|
||||||
|
args: ['--parallel', '--in-place']
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.13.2
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
name: isort (python)
|
48
design.md
48
design.md
|
@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
|
||||||
|
|
||||||
## Objective
|
## Objective
|
||||||
|
|
||||||
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
|
||||||
|
|
||||||
## Related Work
|
## Related Work
|
||||||
|
|
||||||
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
|
This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
|
||||||
|
|
||||||
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
|
- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
|
||||||
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
|
||||||
- Borg
|
- Borg
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ boxsize = [1024., 1024., 1024.]
|
||||||
nmesh = [1024, 1024, 1024]
|
nmesh = [1024, 1024, 1024]
|
||||||
|
|
||||||
# Create distributed frequencies
|
# Create distributed frequencies
|
||||||
kvec = jpm.fftk(nmesh, sharding)
|
kvec = jpm.fftk(nmesh, sharding)
|
||||||
# Generate initial positions
|
# Generate initial positions
|
||||||
particles = jpm.generate_initial_positions(nmesh, sharding)
|
particles = jpm.generate_initial_positions(nmesh, sharding)
|
||||||
|
|
||||||
|
@ -95,10 +95,10 @@ snapshots = jnp.linespace(0.1, 1, 10)
|
||||||
def high_level_api_simulation(cosmo, kvec, particles)
|
def high_level_api_simulation(cosmo, kvec, particles)
|
||||||
# Initial conditions is a distributed 3D array
|
# Initial conditions is a distributed 3D array
|
||||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||||
|
|
||||||
# Create a particular solver
|
# Create a particular solver
|
||||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||||
|
|
||||||
# Initialize and run the simulation
|
# Initialize and run the simulation
|
||||||
state = solver.init(initial_conditions)
|
state = solver.init(initial_conditions)
|
||||||
state = solver.nbody(state) # Will use base leapfrog integrator
|
state = solver.nbody(state) # Will use base leapfrog integrator
|
||||||
|
@ -106,12 +106,12 @@ def high_level_api_simulation(cosmo, kvec, particles)
|
||||||
diffrax_solver = diffrax.Dopri5()
|
diffrax_solver = diffrax.Dopri5()
|
||||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||||
state = solver.nbody(state , diffrax_solver, stepsize_controller, t0=0.1, t1=1 , dt=0.01 , s=snapshots)
|
state = solver.nbody(state , diffrax_solver, stepsize_controller, t0=0.1, t1=1 , dt=0.01 , s=snapshots)
|
||||||
|
|
||||||
# Painting the results
|
# Painting the results
|
||||||
# User defined function to compute weights
|
# User defined function to compute weights
|
||||||
weights = get_weights(particles)
|
weights = get_weights(particles)
|
||||||
density = jpm.cic_paint(state.positions , weights)
|
density = jpm.cic_paint(state.positions , weights)
|
||||||
|
|
||||||
return density
|
return density
|
||||||
|
|
||||||
with spmd_config:
|
with spmd_config:
|
||||||
|
@ -122,17 +122,17 @@ with spmd_config:
|
||||||
def mid_level_api_simulation(cosmo, kvec, particles)
|
def mid_level_api_simulation(cosmo, kvec, particles)
|
||||||
# Initial conditions is a distributed 3D array
|
# Initial conditions is a distributed 3D array
|
||||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||||
|
|
||||||
# Create a particular solver
|
# Create a particular solver
|
||||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||||
|
|
||||||
# Initialize and run the simulation
|
# Initialize and run the simulation
|
||||||
state = solver.init(initial_conditions , kvec)
|
state = solver.init(initial_conditions , kvec)
|
||||||
|
|
||||||
state = solver.lpt(state , a=0.1)
|
state = solver.lpt(state , a=0.1)
|
||||||
# OR
|
# OR
|
||||||
state = solver.lpt2(state , a=0.1) # Does both LPT1 and LPT2
|
state = solver.lpt2(state , a=0.1) # Does both LPT1 and LPT2
|
||||||
|
|
||||||
# Run the nbody simulation
|
# Run the nbody simulation
|
||||||
state = solver.Euler(state , t0=0.1, t1=1 , dt=0.01)
|
state = solver.Euler(state , t0=0.1, t1=1 , dt=0.01)
|
||||||
# OR
|
# OR
|
||||||
|
@ -145,7 +145,7 @@ def mid_level_api_simulation(cosmo, kvec, particles)
|
||||||
term = ODETerm(lambda t, state, args: ode_fn(state, t, args))
|
term = ODETerm(lambda t, state, args: ode_fn(state, t, args))
|
||||||
diffrax_solver = diffrax.Dopri5()
|
diffrax_solver = diffrax.Dopri5()
|
||||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||||
|
|
||||||
ode_solution = diffrax.diffeqsolve(term,
|
ode_solution = diffrax.diffeqsolve(term,
|
||||||
diffrax_solver,
|
diffrax_solver,
|
||||||
t0=0.1,
|
t0=0.1,
|
||||||
|
@ -155,12 +155,12 @@ def mid_level_api_simulation(cosmo, kvec, particles)
|
||||||
saveat=SaveAt(t0=False,t1=True,ts=snapshots),
|
saveat=SaveAt(t0=False,t1=True,ts=snapshots),
|
||||||
args=cosmo,
|
args=cosmo,
|
||||||
stepsize_controller=stepsize_controller)
|
stepsize_controller=stepsize_controller)
|
||||||
|
|
||||||
final_state = ode_solution.ys[-1]
|
final_state = ode_solution.ys[-1]
|
||||||
# User defined function to compute weights
|
# User defined function to compute weights
|
||||||
weights = get_weights(particles)
|
weights = get_weights(particles)
|
||||||
density = jpm.cic_paint(final_state.positions , weights)
|
density = jpm.cic_paint(final_state.positions , weights)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
with spmd_config:
|
with spmd_config:
|
||||||
|
@ -171,7 +171,7 @@ with spmd_config:
|
||||||
def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
||||||
# Initial conditions is a distributed 3D array
|
# Initial conditions is a distributed 3D array
|
||||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||||
|
|
||||||
# Create a particular solver
|
# Create a particular solver
|
||||||
initial_force = pm_forces(inital_conditions,nmesh)
|
initial_force = pm_forces(inital_conditions,nmesh)
|
||||||
# First order LPT
|
# First order LPT
|
||||||
|
@ -189,26 +189,26 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
||||||
p += p2
|
p += p2
|
||||||
|
|
||||||
state = jpm.empty_state(dx , p , kvec)
|
state = jpm.empty_state(dx , p , kvec)
|
||||||
|
|
||||||
def nbody_ode(state, a, cosmo):
|
def nbody_ode(state, a, cosmo):
|
||||||
|
|
||||||
pos, vel , kvec= state.positions, state.velocities , state.kvec
|
pos, vel , kvec= state.positions, state.velocities , state.kvec
|
||||||
|
|
||||||
forces = pm_forces(pos, mesh_shape=nmesh) * 1.5 * cosmo.Omega_m
|
forces = pm_forces(pos, mesh_shape=nmesh) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
term = ODETerm(lambda t, state, args: nbody_ode(state, t, args))
|
term = ODETerm(lambda t, state, args: nbody_ode(state, t, args))
|
||||||
diffrax_solver = diffrax.Dopri5()
|
diffrax_solver = diffrax.Dopri5()
|
||||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
ode_solution = diffeqsolve(term,
|
ode_solution = diffeqsolve(term,
|
||||||
diffrax_solver,
|
diffrax_solver,
|
||||||
t0=0.1,
|
t0=0.1,
|
||||||
|
@ -223,7 +223,7 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
||||||
# User defined function to compute weights
|
# User defined function to compute weights
|
||||||
weights = get_weights(particles)
|
weights = get_weights(particles)
|
||||||
density = jpm.cic_paint(final_state.positions , weights)
|
density = jpm.cic_paint(final_state.positions , weights)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
with spmd_config:
|
with spmd_config:
|
||||||
|
@ -237,4 +237,4 @@ Users can also mix and match the different levels of API to create their own cus
|
||||||
### TODOs
|
### TODOs
|
||||||
|
|
||||||
- [ ] Implement tsc_paint and tsc_compenstate
|
- [ ] Implement tsc_paint and tsc_compenstate
|
||||||
- [ ] Implement a distributed power spectrum calculation
|
- [ ] Implement a distributed power spectrum calculation
|
||||||
|
|
Loading…
Add table
Reference in a new issue