JaxPM/notebooks/04-MultiGPU_PM_Solvers.ipynb
Francois Lanusse 70ab9f1931 remove deprecated stuff
Former-commit-id: da6fdea0556407c968dd31d44af832c997f2645c
2024-12-21 09:58:42 -05:00

3.7 MiB

In [1]:
%pylab inline
%load_ext autoreload
%autoreload 2
%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib
In [2]:
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax_cosmo as jc
In [3]:
import tensorflow as tf
import flowpm
from flowpm.tfpower import linear_matter_power
import flowpm.scipy.interpolate as interpolate
In [4]:
from jaxpm.kernels import *
from jaxpm.painting import *
In [5]:
# Below are a few parameters
box_size = [100., 100., 100.]   # Transverse comoving size of the simulation volume
nc = [100, 100, 100]            # Number of transverse voxels in the simulation volume
batch_size = 1                  # Number of simulations to run in parallel
In [6]:
# Instantiates a cosmology with desired parameters
cosmology = flowpm.cosmology.Planck15()

# Create some initial conditions
k = tf.constant(np.logspace(-4, 1, 128), dtype=tf.float32)
pk = linear_matter_power(cosmology, k)
pk_fun = lambda x: tf.cast(tf.reshape(interpolate.interp_tf(tf.reshape(tf.cast(x, tf.float32), [-1]), k, pk), x.shape), tf.complex64)
initial_conditions = flowpm.linear_field(nc,
                                      box_size, 
                                      pk_fun,         
                                      batch_size=batch_size)

initial_state = flowpm.lpt_init(cosmology, initial_conditions, 0.1)
In [73]:
@tf.function
def solve_tf(init_state):
    final_state = flowpm.nbody(cosmology, initial_state, linspace(0.1,1.,40), nc)
    return final_state
In [74]:
final_state = solve_tf(initial_state)
In [16]:
%timeit final_state = solve_tf(initial_state)
4.89 s ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [17]:
imshow(flowpm.cic_paint(tf.zeros([1]+nc), final_state[0]).numpy()[0].sum(axis=0))
Out[17]:
<matplotlib.image.AxesImage at 0x7f9774507880>
No description has been provided for this image
In [7]:
mesh_shape = nc
kvec = fftk(mesh_shape)

# Define the ODE
def f(state, a, cosmo):
    # Extracts positions and velocity at a given point in
    # the simulation
    pos, vel = state
    
    # Computes the potential given the current positions
    delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), pos))
    pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,r_split=0)
    forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos) 
                      for i in range(3)],axis=-1)
    forces = forces * 1.5 * cosmo.Omega_m
    
    # Computes the update of position (drift)
    dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
    
    # Computes the update of velocity (kick)
    dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
    
    return dpos, dvel
In [8]:
from jax.experimental.ode import odeint
In [9]:
init_state = [initial_state[0,0].numpy(),
              initial_state[1,0].numpy()]

@jax.jit
def solve_ode(init_state):
    return odeint(f, init_state, 
             jnp.linspace(0.1,1.0,10), 
             jc.Planck15(),rtol=1e-5, atol=1e-5 )
In [12]:
res = solve_ode(init_state)
In [10]:
%timeit res = solve_ode(init_state)
352 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [18]:
imshow(cic_paint(jnp.zeros(mesh_shape), initial_state[0,0].numpy()).sum(axis=0))
Out[18]:
<matplotlib.image.AxesImage at 0x7f977c4261a0>
No description has been provided for this image
In [27]:
imshow(cic_paint(jnp.zeros(mesh_shape), res[0][-1]).sum(axis=0)); colorbar()
Out[27]:
<matplotlib.colorbar.Colorbar at 0x7f94181834c0>
No description has been provided for this image
In [28]:
imshow((cic_paint(jnp.zeros(mesh_shape), res[0][-1]) - 
        flowpm.cic_paint(tf.zeros([1]+mesh_shape), final_state[0]).numpy()[0]).sum(axis=0)); colorbar()
Out[28]:
<matplotlib.colorbar.Colorbar at 0x7f940073dd80>
No description has been provided for this image
In [29]:
from DifferentiableHOS.pk import power_spectrum
In [30]:
k, pk = power_spectrum(flowpm.cic_paint(tf.zeros([1]+mesh_shape), final_state[0]),
                       boxsize=np.array([100.,100.,100.]),                      
                       kmin=0.1,dk=2*np.pi/100.)

k, pk_jax = power_spectrum(tf.convert_to_tensor(cic_paint(jnp.zeros(mesh_shape), res[0][-1]).reshape([1,100,100,100])),
                       boxsize=np.array([100.,100.,100.]),                      
                       kmin=0.1,dk=2*np.pi/100.)
In [31]:
loglog(k,pk[0])
loglog(k,pk_jax[0])
Out[31]:
[<matplotlib.lines.Line2D at 0x7f94005fc850>]
No description has been provided for this image
In [32]:
semilogx(k,(pk[0] - pk_jax[0])/pk[0])
Out[32]:
[<matplotlib.lines.Line2D at 0x7f94005693f0>]
No description has been provided for this image
In [58]:
import diffrax
from diffrax import diffeqsolve, ODETerm, Dopri5, LeapfrogMidpoint, PIDController

cosmo = jc.Planck15()

def f_diffrax(t, state, args):
    return jnp.stack(f(state, t, cosmo),axis=0)

term = ODETerm(f_diffrax)
solver = Dopri5()
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
In [66]:
solution = diffeqsolve(term, solver, t0=0.1, t1=1., dt0=0.1,
                       y0=jnp.stack(init_state,axis=0), 
                       stepsize_controller=stepsize_controller)
In [67]:
solution
Out[67]:
Solution(t0=DeviceArray(0.1, dtype=float32), t1=DeviceArray(1., dtype=float32), ts=DeviceArray([1.], dtype=float32, weak_type=True), ys=DeviceArray([[[[ -2.628675  ,   5.347485  ,  -1.5904427 ],
               [ -2.5886178 ,   5.4675574 ,  -1.8807749 ],
               [ -2.6215632 ,   5.6189113 ,  -1.9770483 ],
               ...,
               [ 97.6788    , 105.22156   ,  96.64059   ],
               [ 97.459785  , 105.74062   ,  97.79143   ],
               [ 97.26484   , 105.98253   ,  98.37041   ]],

              [[ -1.7246462 ,  -0.27205068,   0.6488323 ],
               [ -1.6212564 ,  -0.29098457,   0.29932806],
               [ -1.4803871 ,  -0.176678  ,  -0.5965874 ],
               ...,
               [ -2.21784   ,   2.4735427 ,   1.3182787 ],
               [ -2.3406084 ,   2.282835  ,   1.8463919 ],
               [ -2.175532  ,   0.7863351 ,   1.1335845 ]]]],            dtype=float32, weak_type=True), interpolation=None, stats={'compiled_num_steps': None, 'max_steps': 4096, 'num_accepted_steps': DeviceArray(19, dtype=int32, weak_type=True), 'num_rejected_steps': DeviceArray(4, dtype=int32, weak_type=True), 'num_steps': DeviceArray(23, dtype=int32, weak_type=True)}, result=DeviceArray(0, dtype=int32, weak_type=True), solver_state=None, controller_state=None, made_jump=None)
In [68]:
resx = solution.ys[0,0]
In [69]:
imshow(cic_paint(jnp.zeros(mesh_shape), resx).sum(axis=0),cmap='gist_stern'); colorbar()
Out[69]:
<matplotlib.colorbar.Colorbar at 0x7f9368964cd0>
No description has been provided for this image
In [75]:
k, pk = power_spectrum(flowpm.cic_paint(tf.zeros([1]+mesh_shape), final_state[0]),
                       boxsize=array([100.,100.,100.]),                      
                       kmin=0.1,dk=2*np.pi/100.)

k, pk_jax = power_spectrum(tf.convert_to_tensor(cic_paint(jnp.zeros(mesh_shape), resx).reshape([1,100,100,100])),
                       boxsize=array([100.,100.,100.]),                      
                       kmin=0.1,dk=2*np.pi/100.)
In [76]:
loglog(k,pk[0], label='flowpm')
loglog(k,pk_jax[0], label='jax')
legend()
Out[76]:
<matplotlib.legend.Legend at 0x7f97746730d0>
No description has been provided for this image
In [77]:
semilogx(k,(pk_jax[0] - pk[0])/pk[0])
Out[77]:
[<matplotlib.lines.Line2D at 0x7f978c1fb6a0>]
No description has been provided for this image
In [ ]: