fix formatting from main

This commit is contained in:
Wassim KABALAN 2024-08-03 00:45:35 +02:00
parent 02754cf452
commit 8da3149581
2 changed files with 57 additions and 43 deletions

View file

@ -183,19 +183,23 @@ def pgd_correction(pos, mesh_shape, params):
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range=PGD_kernel(kvec, kl, ks)
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
PGD_range = PGD_kernel(kvec, kl, ks)
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
forces_pgd = jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
],
axis=-1)
dpos_pgd = forces_pgd * alpha
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
for i in range(3)],axis=-1)
dpos_pgd = forces_pgd*alpha
return dpos_pgd
def make_neural_ode_fn(model, mesh_shape):
def neural_nbody_ode(state, a, cosmo, params):
"""
state is a tuple (position, velocities)
@ -208,15 +212,19 @@ def make_neural_ode_fn(model, mesh_shape):
delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=0)
# Apply a correction filter
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
# Computes gravitational forces
forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos)
for i in range(3)],axis=-1)
forces = jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), pos)
for i in range(3)
],
axis=-1)
forces = forces * 1.5 * cosmo.Omega_m
@ -227,4 +235,5 @@ def make_neural_ode_fn(model, mesh_shape):
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
return neural_nbody_ode

View file

@ -83,53 +83,58 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
return kbins, P / norm
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
"""
def cross_correlation_coefficients(field_a,
field_b,
kmin=5,
dk=0.5,
boxsize=False):
"""
Calculate the cross correlation coefficients given two real space field
Args:
field_a: real valued field
field_b: real valued field
field_a: real valued field
field_b: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
Returns:
kbins: the central value of the bins for plotting
P / norm: normalized cross correlation coefficient between two field a and b
P / norm: normalized cross correlation coefficient between two field a and b
"""
shape = field_a.shape
nx, ny, nz = shape
shape = field_a.shape
nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform
fft_image_a = jnp.fft.fftn(field_a)
fft_image_b = jnp.fft.fftn(field_b)
#fast fourier transform
fft_image_a = jnp.fft.fftn(field_a)
fft_image_b = jnp.fft.fftn(field_b)
#absolute value of fast fourier transform
pk = fft_image_a * jnp.conj(fft_image_b)
#absolute value of fast fourier transform
pk = fft_image_a * jnp.conj(fft_image_b)
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
return kbins, P / norm
def gaussian_smoothing(im, sigma):