mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
3.7 MiB
3.7 MiB
In [1]:
# Pull the numpy/matplotlib names used unqualified below (np, linspace, array,
# imshow, loglog, colorbar, ...) into the namespace and render figures inline.
%pylab inline
# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2
In [2]:
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax_cosmo as jc
In [3]:
import tensorflow as tf
import flowpm
from flowpm.tfpower import linear_matter_power
import flowpm.scipy.interpolate as interpolate
In [4]:
from jaxpm.kernels import *
from jaxpm.painting import *
In [5]:
# Simulation configuration shared by the FlowPM and JAX runs below.
# Both shapes are plain Python lists because later cells build batched shapes
# by list concatenation (e.g. `[1] + nc`).
box_size = [100., 100., 100.] # Comoving size of the simulation volume along each axis (units presumably Mpc/h -- TODO confirm)
nc = [100, 100, 100] # Number of voxels (mesh cells) along each axis
batch_size = 1 # Number of simulations to run in parallel
In [6]:
# Planck 2015 cosmology used by the FlowPM reference run
cosmology = flowpm.cosmology.Planck15()

# Tabulate the linear matter power spectrum on 128 log-spaced wavenumbers
k = tf.constant(np.logspace(-4, 1, 128), dtype=tf.float32)
pk = linear_matter_power(cosmology, k)


def pk_fun(x):
    """Interpolate the tabulated power spectrum (`k`, `pk`) at wavenumbers `x`.

    The input is cast to float32 and flattened for the interpolation; the
    result is reshaped back to `x.shape` and cast to complex64.
    """
    flat_x = tf.reshape(tf.cast(x, tf.float32), [-1])
    flat_pk = interpolate.interp_tf(flat_x, k, pk)
    return tf.cast(tf.reshape(flat_pk, x.shape), tf.complex64)


# Build the linear field from that spectrum, then initialise the particle
# state with LPT at 0.1 (the value the integrations below start from).
initial_conditions = flowpm.linear_field(nc,
                                         box_size,
                                         pk_fun,
                                         batch_size=batch_size)
initial_state = flowpm.lpt_init(cosmology, initial_conditions, 0.1)
In [73]:
@tf.function
def solve_tf(init_state):
    """Evolve a particle state with the FlowPM N-body solver.

    Args:
        init_state: Initial state tensor, e.g. the output of `flowpm.lpt_init`.

    Returns:
        The state returned by `flowpm.nbody` after stepping through 40 values
        linearly spaced between 0.1 and 1.
    """
    # Integrate the argument rather than the notebook-global `initial_state`:
    # the original body ignored `init_state`, so the function always evolved
    # whatever the global happened to hold, whatever the caller passed in.
    final_state = flowpm.nbody(cosmology, init_state, linspace(0.1, 1., 40), nc)
    return final_state
In [74]:
# Run the FlowPM reference simulation; `final_state` is reused by the comparison cells below.
final_state = solve_tf(initial_state)
In [16]:
# Benchmark the FlowPM solver.
%timeit final_state = solve_tf(initial_state)
In [17]:
# Paint the FlowPM final state (element 0 of `final_state`) onto an empty
# batched mesh, take the first batch entry and sum along the first mesh axis.
imshow(flowpm.cic_paint(tf.zeros([1]+nc), final_state[0]).numpy()[0].sum(axis=0))
Out[17]:
In [7]:
# The JAX solver uses the same mesh resolution as the FlowPM run.
mesh_shape = nc
# Computed once here; `f` passes it to the Laplace, long-range and gradient
# kernels on every evaluation (presumably the Fourier wavevector grid -- confirm in jaxpm.kernels).
kvec = fftk(mesh_shape)
# Define the ODE
def f(state, a, cosmo):
    """Right-hand side of the particle-mesh N-body ODE.

    Args:
        state: Pair `(pos, vel)` of particle positions and velocities.
        a: Integration variable, passed to `jc.background.Esqr` as the scale factor.
        cosmo: Cosmology object; `cosmo.Omega_m` is read and the object is
            passed to `jc.background.Esqr`.

    Returns:
        Tuple `(dpos, dvel)` with the derivatives of positions (drift) and
        velocities (kick) with respect to `a`.
    """
    positions, velocities = state

    # Paint the particles on an empty mesh and go to Fourier space
    density_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
    # Potential from the Laplace kernel and the long-range kernel with r_split=0
    potential_k = density_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)

    # One force component per axis, read back at the particle positions
    components = []
    for axis in range(3):
        force_mesh = jnp.fft.irfftn(gradient_kernel(kvec, axis) * potential_k)
        components.append(cic_read(force_mesh, positions))
    forces = jnp.stack(components, axis=-1) * 1.5 * cosmo.Omega_m

    # sqrt(Esqr) is shared by the drift and kick prefactors
    expansion = jnp.sqrt(jc.background.Esqr(cosmo, a))
    # Position update (drift)
    dpos = 1. / (a**3 * expansion) * velocities
    # Velocity update (kick)
    dvel = 1. / (a**2 * expansion) * forces
    return dpos, dvel
In [8]:
from jax.experimental.ode import odeint
In [9]:
# Convert the FlowPM state to NumPy arrays for the JAX solvers. `f` unpacks the
# pair as (pos, vel), so first index 0 is position and 1 is velocity; the second
# index 0 presumably selects the single batch element -- confirm against flowpm.
init_state = [initial_state[0,0].numpy(),
initial_state[1,0].numpy()]
@jax.jit
def solve_ode(init_state):
    """Integrate the N-body ODE `f` with JAX's `odeint`.

    Args:
        init_state: `[pos, vel]` initial particle state.

    Returns:
        The `odeint` result for 10 output values linearly spaced between 0.1
        and 1.0, using a Planck15 `jax_cosmo` cosmology and rtol = atol = 1e-5.
    """
    scale_factors = jnp.linspace(0.1, 1.0, 10)
    planck15 = jc.Planck15()
    return odeint(f, init_state, scale_factors, planck15, rtol=1e-5, atol=1e-5)
In [12]:
# Run the JAX odeint simulation; later cells read the final positions as `res[0][-1]`.
res = solve_ode(init_state)
In [10]:
# Benchmark the jitted JAX solver.
%timeit res = solve_ode(init_state)
In [18]:
# Density of the initial particle positions, summed along the first mesh axis.
imshow(cic_paint(jnp.zeros(mesh_shape), initial_state[0,0].numpy()).sum(axis=0))
Out[18]:
In [27]:
# Density of the odeint result: positions (`res[0]`) at the last output time
# (`[-1]`), summed along the first mesh axis.
imshow(cic_paint(jnp.zeros(mesh_shape), res[0][-1]).sum(axis=0)); colorbar()
Out[27]:
In [28]:
# Difference map: JAX (odeint) density minus FlowPM density, summed along the
# first mesh axis.
imshow((cic_paint(jnp.zeros(mesh_shape), res[0][-1]) -
flowpm.cic_paint(tf.zeros([1]+mesh_shape), final_state[0]).numpy()[0]).sum(axis=0)); colorbar()
Out[28]:
In [29]:
from DifferentiableHOS.pk import power_spectrum
In [30]:
# Power spectra of the FlowPM and JAX (odeint) final density fields.
# The box size, the batched mesh shape and the bin width 2*pi/L are derived
# from the notebook configuration (`box_size`, `mesh_shape`) instead of being
# hard-coded, so this cell stays consistent if the simulation parameters change.
k, pk = power_spectrum(flowpm.cic_paint(tf.zeros([1] + mesh_shape), final_state[0]),
                       boxsize=np.array(box_size),
                       kmin=0.1, dk=2 * np.pi / box_size[0])
k, pk_jax = power_spectrum(
    tf.convert_to_tensor(
        cic_paint(jnp.zeros(mesh_shape), res[0][-1]).reshape([1] + mesh_shape)),
    boxsize=np.array(box_size),
    kmin=0.1, dk=2 * np.pi / box_size[0])
In [31]:
# Overlay both spectra; label the curves so the figure can be read on its own,
# matching the labelled diffrax comparison plot later in the notebook.
loglog(k, pk[0], label='flowpm')
loglog(k, pk_jax[0], label='jax')
legend()
Out[31]:
In [32]:
# Relative difference of the spectra, with sign convention (flowpm - jax) / flowpm.
semilogx(k,(pk[0] - pk_jax[0])/pk[0])
Out[32]:
In [58]:
import diffrax
from diffrax import diffeqsolve, ODETerm, Dopri5, LeapfrogMidpoint, PIDController
# Cosmology captured by the vector-field closure below
cosmo = jc.Planck15()


def f_diffrax(t, state, args):
    """Adapter exposing `f` through a `(t, state, args)` vector-field signature.

    `state` is unpacked by `f` into positions and velocities, `t` is forwarded
    as the integration variable, and `args` is unused. The two derivatives are
    returned stacked along a new leading axis.
    """
    dpos, dvel = f(state, t, cosmo)
    return jnp.stack((dpos, dvel), axis=0)
# Wrap the vector field as a diffrax ODE term.
term = ODETerm(f_diffrax)
# Dopri5 solver from diffrax.
solver = Dopri5()
# Adaptive step-size control with the same tolerances passed to `odeint` in `solve_ode`.
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
In [66]:
# Integrate from 0.1 to 1 with an initial step of 0.1. The initial state is
# stacked along axis 0, so index 0 is position and index 1 is velocity.
solution = diffeqsolve(term, solver, t0=0.1, t1=1., dt0=0.1,
y0=jnp.stack(init_state,axis=0),
stepsize_controller=stepsize_controller)
In [67]:
# Display the diffrax solution object.
solution
Out[67]:
In [68]:
# First saved state, first stacked component (positions). No `saveat` was
# passed to `diffeqsolve`, so this is presumably the state at t1 -- confirm
# against the diffrax defaults.
resx = solution.ys[0,0]
In [69]:
# Density of the diffrax result, summed along the first mesh axis.
imshow(cic_paint(jnp.zeros(mesh_shape), resx).sum(axis=0),cmap='gist_stern'); colorbar()
Out[69]:
In [75]:
# Power spectra of the FlowPM and JAX (diffrax) final density fields.
# The box size, the batched mesh shape and the bin width 2*pi/L are derived
# from the notebook configuration (`box_size`, `mesh_shape`) instead of being
# hard-coded; `np.array` is used as in the odeint comparison cell.
k, pk = power_spectrum(flowpm.cic_paint(tf.zeros([1] + mesh_shape), final_state[0]),
                       boxsize=np.array(box_size),
                       kmin=0.1, dk=2 * np.pi / box_size[0])
k, pk_jax = power_spectrum(
    tf.convert_to_tensor(
        cic_paint(jnp.zeros(mesh_shape), resx).reshape([1] + mesh_shape)),
    boxsize=np.array(box_size),
    kmin=0.1, dk=2 * np.pi / box_size[0])
In [76]:
# Overlay the FlowPM and JAX (diffrax) power spectra on log-log axes.
loglog(k,pk[0], label='flowpm')
loglog(k,pk_jax[0], label='jax')
legend()
Out[76]:
In [77]:
# Relative difference of the spectra, with sign convention (jax - flowpm) / flowpm.
semilogx(k,(pk_jax[0] - pk[0])/pk[0])
Out[77]:
In [ ]: