mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
Add design
This commit is contained in:
parent
5f463450d1
commit
f49ed6fe70
1 changed files with 188 additions and 0 deletions
188
design.md
188
design.md
|
@ -50,3 +50,191 @@ state = solver.nbody(state)
|
|||
density = jpm.zeros(boxsize, nmesh)
|
||||
density = jpm.paint(density, state.positions)
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Distributed implementation
|
||||
|
||||
|
||||
```python
|
||||
import jaxpm as jpm
|
||||
from jpm import SPMDConfig
|
||||
import jax_cosmo as jc
|
||||
import jax.numpy as jnp
|
||||
import diffrax
|
||||
|
||||
jax.distributed.initialize()
|
||||
|
||||
pdims = (4, 4) # Got this from autotuning
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices, axis_names=('y', 'z'))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('z', 'y'))
|
||||
|
||||
# Creates initial conditions
|
||||
boxsize = [1024., 1024., 1024.]
|
||||
nmesh = [1024, 1024, 1024]
|
||||
|
||||
# Create distributed frequencies
|
||||
kvec = jpm.fftk(nmesh, sharding)
|
||||
# Generate initial positions
|
||||
particles = jpm.generate_initial_positions(nmesh, sharding)
|
||||
|
||||
# This contains the mesh, pdims that are necessary to create the shard_mapped functions
|
||||
spmd_config = SPMDConfig(sharding)
|
||||
|
||||
# fftk and generate_initial_positions cannot be jitted because otherwise
|
||||
# We will be closing on non addressable array which is not allowed
|
||||
# https://github.com/google/jax/issues/22218
|
||||
|
||||
# Instantiate differentiable cosmology object
|
||||
cosmo = jc.Planck(Omega_c=0.3,, sigma8= 0.8)
|
||||
snapshots = jnp.linespace(0.1, 1, 10)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def high_level_api_simulation(cosmo, kvec, particles)
|
||||
# Initial conditions is a distributed 3D array
|
||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||
|
||||
# Create a particular solver
|
||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||
|
||||
# Initialize and run the simulation
|
||||
state = solver.init(initial_conditions)
|
||||
state = solver.nbody(state) # Will use base leapfrog integrator
|
||||
# OR
|
||||
diffrax_solver = diffrax.Dopri5()
|
||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||
state = solver.nbody(state , diffrax_solver, stepsize_controller, t0=0.1, t1=1 , dt=0.01 , s=snapshots)
|
||||
|
||||
# Painting the results
|
||||
# User defined function to compute weights
|
||||
weights = get_weights(particles)
|
||||
density = jpm.cic_paint(state.positions , weights)
|
||||
|
||||
return density
|
||||
|
||||
with spmd_config:
|
||||
density = high_level_api_simulation(cosmo)
|
||||
density_dt = jax.grad(high_level_api_simulation)(cosmo)
|
||||
|
||||
@jax.jit
|
||||
def mid_level_api_simulation(cosmo, kvec, particles)
|
||||
# Initial conditions is a distributed 3D array
|
||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||
|
||||
# Create a particular solver
|
||||
solver = jpm.solvers.fastpm(cosmo, B=1)
|
||||
|
||||
# Initialize and run the simulation
|
||||
state = solver.init(initial_conditions , kvec)
|
||||
|
||||
state = solver.lpt(state , a=0.1)
|
||||
# OR
|
||||
state = solver.lpt2(state , a=0.1) # Does both LPT1 and LPT2
|
||||
|
||||
# Run the nbody simulation
|
||||
state = solver.Euler(state , t0=0.1, t1=1 , dt=0.01)
|
||||
# OR
|
||||
state = solver.Euler(state , t0=0.1, t1=1 , ts=snaphots)
|
||||
|
||||
# Also provide a leapfrog integrator
|
||||
|
||||
# Or use external integrator
|
||||
ode_fn = jpm.make_ode_fn(nmesh)
|
||||
term = ODETerm(lambda t, state, args: ode_fn(state, t, args))
|
||||
diffrax_solver = diffrax.Dopri5()
|
||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||
|
||||
ode_solution = diffrax.diffeqsolve(term,
|
||||
diffrax_solver,
|
||||
t0=0.1,
|
||||
t1=1.,
|
||||
dt0=0.01,
|
||||
y0=state,
|
||||
saveat=SaveAt(t0=False,t1=True,ts=snapshots),
|
||||
args=cosmo,
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
final_state = ode_solution.ys[-1]
|
||||
# User defined function to compute weights
|
||||
weights = get_weights(particles)
|
||||
density = jpm.cic_paint(final_state.positions , weights)
|
||||
|
||||
return state
|
||||
|
||||
with spmd_config:
|
||||
density = mid_level_api_simulation(cosmo)
|
||||
density_dt = jax.grad(mid_level_api_simulation)(cosmo)
|
||||
|
||||
@jax.jit
|
||||
def low_level_api_simulation(cosmo, kvec, particles, a=0.1)
|
||||
# Initial conditions is a distributed 3D array
|
||||
inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh,sharding, kvec , dtype='float32')
|
||||
|
||||
# Create a particular solver
|
||||
initial_force = pm_forces(inital_conditions,nmesh)
|
||||
# First order LPT
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
||||
# LPT2
|
||||
delta2 = jpm.generate_2lpt(inital_conditions, nmesh,kvec)
|
||||
init_force2 = pm_forces(delta2,nmesh)
|
||||
# Taken from Hugo Simon
|
||||
dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2 # D2 is renormalized: - D2 = 3/7 * growth_factor_second
|
||||
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
|
||||
|
||||
dx += dx2
|
||||
p += p2
|
||||
|
||||
state = jpm.empty_state(dx , p , kvec)
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
|
||||
pos, vel , kvec= state.positions, state.velocities , state.kvec
|
||||
|
||||
forces = pm_forces(pos, mesh_shape=nmesh) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
return dpos, dvel
|
||||
|
||||
term = ODETerm(lambda t, state, args: nbody_ode(state, t, args))
|
||||
diffrax_solver = diffrax.Dopri5()
|
||||
stepsize_controller = diffrax.PIDController(rtol=1e-5,atol=1e-5)
|
||||
|
||||
|
||||
ode_solution = diffeqsolve(term,
|
||||
diffrax_solver,
|
||||
t0=0.1,
|
||||
t1=1.,
|
||||
dt0=0.01,
|
||||
y0=state,
|
||||
saveat=SaveAt(t0=False,t1=True,ts=snapshots),
|
||||
args=cosmo,
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
final_state = ode_solution.ys[-1]
|
||||
# User defined function to compute weights
|
||||
weights = get_weights(particles)
|
||||
density = jpm.cic_paint(final_state.positions , weights)
|
||||
|
||||
return state
|
||||
|
||||
with spmd_config:
|
||||
density = low_level_api_simulation(cosmo)
|
||||
density_dt = jax.grad(low_level_api_simulation)(cosmo)
|
||||
|
||||
```
|
||||
|
||||
Users can also mix and match the different levels of API to create their own custom simulations.
|
||||
|
||||
### TODOs
|
||||
|
||||
- [ ] Implement tsc_paint and tsc_compenstate
|
||||
- [ ] Implement a distributed power spectrum calculation
|
Loading…
Add table
Reference in a new issue