forked from Aquila-Consortium/JaxPM_highres
Merge pull request #14 from DifferentiableUniverseInitiative/neural_ode
Fourier-space Neural Network scheme
This commit is contained in:
commit
0f81a89d57
2 changed files with 93 additions and 13 deletions
57
jaxpm/pm.py
57
jaxpm/pm.py
|
@ -83,7 +83,7 @@ def make_ode_fn(mesh_shape):
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def pgd_correction(pos, params):
|
def pgd_correction(pos, mesh_shape, params):
|
||||||
"""
|
"""
|
||||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
|
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
|
||||||
args:
|
args:
|
||||||
|
@ -91,20 +91,51 @@ def pgd_correction(pos, params):
|
||||||
params: [alpha, kl, ks] pgd parameters
|
params: [alpha, kl, ks] pgd parameters
|
||||||
"""
|
"""
|
||||||
kvec = fftk(mesh_shape)
|
kvec = fftk(mesh_shape)
|
||||||
|
|
||||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
alpha, kl, ks = params
|
alpha, kl, ks = params
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
delta_k = jnp.fft.rfftn(delta)
|
||||||
PGD_range = PGD_kernel(kvec, kl, ks)
|
PGD_range=PGD_kernel(kvec, kl, ks)
|
||||||
|
|
||||||
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
|
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
|
||||||
|
|
||||||
forces_pgd = jnp.stack([
|
|
||||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
|
||||||
for i in range(3)
|
|
||||||
],
|
|
||||||
axis=-1)
|
|
||||||
|
|
||||||
dpos_pgd = forces_pgd * alpha
|
|
||||||
|
|
||||||
|
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
||||||
|
for i in range(3)],axis=-1)
|
||||||
|
|
||||||
|
dpos_pgd = forces_pgd*alpha
|
||||||
|
|
||||||
return dpos_pgd
|
return dpos_pgd
|
||||||
|
|
||||||
|
|
||||||
|
def make_neural_ode_fn(model, mesh_shape):
|
||||||
|
def neural_nbody_ode(state, a, cosmo, params):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
pos, vel = state
|
||||||
|
kvec = fftk(mesh_shape)
|
||||||
|
|
||||||
|
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||||
|
|
||||||
|
delta_k = jnp.fft.rfftn(delta)
|
||||||
|
|
||||||
|
# Computes gravitational potential
|
||||||
|
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
|
||||||
|
|
||||||
|
# Apply a correction filter
|
||||||
|
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
|
||||||
|
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
||||||
|
|
||||||
|
# Computes gravitational forces
|
||||||
|
forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos)
|
||||||
|
for i in range(3)],axis=-1)
|
||||||
|
|
||||||
|
forces = forces * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
# Computes the update of position (drift)
|
||||||
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
|
return dpos, dvel
|
||||||
|
return neural_nbody_ode
|
||||||
|
|
|
@ -83,6 +83,55 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
||||||
return kbins, P / norm
|
return kbins, P / norm
|
||||||
|
|
||||||
|
|
||||||
|
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
||||||
|
"""
|
||||||
|
Calculate the cross correlation coefficients given two real space field
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
field_a: real valued field
|
||||||
|
field_b: real valued field
|
||||||
|
kmin: minimum k-value for binned powerspectra
|
||||||
|
dk: differential in each kbin
|
||||||
|
boxsize: length of each boxlength (can be strangly shaped?)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
kbins: the central value of the bins for plotting
|
||||||
|
P / norm: normalized cross correlation coefficient between two field a and b
|
||||||
|
|
||||||
|
"""
|
||||||
|
shape = field_a.shape
|
||||||
|
nx, ny, nz = shape
|
||||||
|
|
||||||
|
#initialze values related to powerspectra (mode bins and weights)
|
||||||
|
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||||
|
|
||||||
|
#fast fourier transform
|
||||||
|
fft_image_a = jnp.fft.fftn(field_a)
|
||||||
|
fft_image_b = jnp.fft.fftn(field_b)
|
||||||
|
|
||||||
|
#absolute value of fast fourier transform
|
||||||
|
pk = fft_image_a * jnp.conj(fft_image_b)
|
||||||
|
|
||||||
|
#calculating powerspectra
|
||||||
|
real = jnp.real(pk).reshape([-1])
|
||||||
|
imag = jnp.imag(pk).reshape([-1])
|
||||||
|
|
||||||
|
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
||||||
|
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||||
|
|
||||||
|
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||||
|
|
||||||
|
#normalization for powerspectra
|
||||||
|
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
||||||
|
|
||||||
|
#find central values of each bin
|
||||||
|
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||||
|
|
||||||
|
return kbins, P / norm
|
||||||
|
|
||||||
|
|
||||||
def gaussian_smoothing(im, sigma):
|
def gaussian_smoothing(im, sigma):
|
||||||
"""
|
"""
|
||||||
im: 2d image
|
im: 2d image
|
||||||
|
|
Loading…
Add table
Reference in a new issue