diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f44eaca --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/google/yapf + rev: v0.40.2 + hooks: + - id: yapf + args: ['--parallel', '--in-place'] +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) diff --git a/design.md b/design.md index 15ec417..fc2be6a 100644 --- a/design.md +++ b/design.md @@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern ## Objective -Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. +Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. ## Related Work This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models. - [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow -- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD +- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD - Borg @@ -75,7 +75,7 @@ boxsize = [1024., 1024., 1024.] nmesh = [1024, 1024, 1024] # Create distributed frequencies -kvec = jpm.fftk(nmesh, sharding) +kvec = jpm.fftk(nmesh, sharding) # Generate initial positions particles = jpm.generate_initial_positions(nmesh, sharding) @@ -95,10 +95,10 @@ snapshots = jnp.linespace(0.1, 1, 10) def high_level_api_simulation(cosmo, kvec, particles) # Initial conditions is a distributed 3D array inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32') - + # Create a particular solver - solver = jpm.solvers.fastpm(cosmo, B=1) - + solver = jpm.solvers.fastpm(cosmo, B=1) + # Initialize and run the simulation state = solver.init(initial_conditions) state = solver.nbody(state) # Will use base leapfrog integrator @@ -106,12 +106,12 @@ def high_level_api_simulation(cosmo, kvec, particles) diffrax_solver = diffrax.Dopri5() stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5) state = solver.nbody(state , diffrax_solver, stepsize_controller, t0=0.1, t1=1 , dt=0.01 , s=snapshots) - + # Painting the results # User defined function to compute weights weights = get_weights(particles) density = jpm.cic_paint(state.positions , weights) - + return density with spmd_config: @@ -122,17 +122,17 @@ with spmd_config: def mid_level_api_simulation(cosmo, kvec, particles) # Initial conditions is a distributed 3D array inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32') - + # Create a particular solver - solver = jpm.solvers.fastpm(cosmo, B=1) - + solver = jpm.solvers.fastpm(cosmo, B=1) + # Initialize and run the simulation state = solver.init(initial_conditions , kvec) state = solver.lpt(state , a=0.1) # OR state = solver.lpt2(state , a=0.1) # Does both LPT1 and LPT2 - + # Run the nbody simulation state = solver.Euler(state , t0=0.1, t1=1 , dt=0.01) # OR @@ -145,7 +145,7 @@ def mid_level_api_simulation(cosmo, kvec, particles) term = ODETerm(lambda t, state, args: ode_fn(state, t, args)) diffrax_solver = diffrax.Dopri5() stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5) - + ode_solution = diffrax.diffeqsolve(term, diffrax_solver, t0=0.1, @@ -155,12 +155,12 @@ def mid_level_api_simulation(cosmo, kvec, particles) saveat=SaveAt(t0=False,t1=True,ts=snapshots), args=cosmo, stepsize_controller=stepsize_controller) - + final_state = ode_solution.ys[-1] # User defined function to compute weights weights = get_weights(particles) density = jpm.cic_paint(final_state.positions , weights) - + return state with spmd_config: @@ -171,7 +171,7 @@ with spmd_config: def low_level_api_simulation(cosmo, kvec, particles, a=0.1) # Initial conditions is a distributed 3D array inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32') - + # Create a particular solver initial_force = pm_forces(inital_conditions,nmesh) # First order LPT @@ -189,26 +189,26 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1) p += p2 state = jpm.empty_state(dx , p , kvec) - + def nbody_ode(state, a, cosmo): pos, vel , kvec= state.positions, state.velocities , state.kvec - + forces = pm_forces(pos, mesh_shape=nmesh) * 1.5 * cosmo.Omega_m - + # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - + # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces - + return dpos, dvel term = ODETerm(lambda t, state, args: nbody_ode(state, t, args)) diffrax_solver = diffrax.Dopri5() stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5) - + ode_solution = diffeqsolve(term, diffrax_solver, t0=0.1, @@ -223,7 +223,7 @@ def low_level_api_simulation(cosmo, kvec, particles, a=0.1) # User defined function to compute weights weights = get_weights(particles) density = jpm.cic_paint(final_state.positions , weights) - + return state with spmd_config: @@ -237,4 +237,4 @@ Users can also mix and match the different levels of API to create their own cus ### TODOs - [ ] Implement tsc_paint and tsc_compenstate -- [ ] Implement a distributed power spectrum calculation \ No newline at end of file +- [ ] Implement a distributed power spectrum calculation