From 8da3149581b85e5cc24f69dd50fa92099596e4bb Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sat, 3 Aug 2024 00:45:35 +0200 Subject: [PATCH] fix formatting from main --- jaxpm/pm.py | 35 +++++++++++++++++---------- jaxpm/utils.py | 65 +++++++++++++++++++++++++++----------------------- 2 files changed, 57 insertions(+), 43 deletions(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 5f370d8..377df8e 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -183,19 +183,23 @@ def pgd_correction(pos, mesh_shape, params): delta = cic_paint(jnp.zeros(mesh_shape), pos) alpha, kl, ks = params delta_k = jnp.fft.rfftn(delta) - PGD_range=PGD_kernel(kvec, kl, ks) - - pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range + PGD_range = PGD_kernel(kvec, kl, ks) + + pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range + + forces_pgd = jnp.stack([ + cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos) + for i in range(3) + ], + axis=-1) + + dpos_pgd = forces_pgd * alpha - forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos) - for i in range(3)],axis=-1) - - dpos_pgd = forces_pgd*alpha - return dpos_pgd def make_neural_ode_fn(model, mesh_shape): + def neural_nbody_ode(state, a, cosmo, params): """ state is a tuple (position, velocities) @@ -208,15 +212,19 @@ def make_neural_ode_fn(model, mesh_shape): delta_k = jnp.fft.rfftn(delta) # Computes gravitational potential - pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, + r_split=0) # Apply a correction filter - kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) - pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) + kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec)) + pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a))) # Computes gravitational forces - forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos) - for i in range(3)],axis=-1) + forces = jnp.stack([ + cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), pos) + for i in range(3) + ], + axis=-1) forces = forces * 1.5 * cosmo.Omega_m @@ -227,4 +235,5 @@ def make_neural_ode_fn(model, mesh_shape): dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces return dpos, dvel + return neural_nbody_ode diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 1593ba0..7c6af44 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -83,53 +83,58 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): return kbins, P / norm -def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False): - """ +def cross_correlation_coefficients(field_a, + field_b, + kmin=5, + dk=0.5, + boxsize=False): + """ Calculate the cross correlation coefficients given two real space field - + Args: - - field_a: real valued field - field_b: real valued field + + field_a: real valued field + field_b: real valued field kmin: minimum k-value for binned powerspectra dk: differential in each kbin boxsize: length of each boxlength (can be strangly shaped?) - + Returns: - + kbins: the central value of the bins for plotting - P / norm: normalized cross correlation coefficient between two field a and b - + P / norm: normalized cross correlation coefficient between two field a and b + """ - shape = field_a.shape - nx, ny, nz = shape + shape = field_a.shape + nx, ny, nz = shape - #initialze values related to powerspectra (mode bins and weights) - dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) + #initialze values related to powerspectra (mode bins and weights) + dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) - #fast fourier transform - fft_image_a = jnp.fft.fftn(field_a) - fft_image_b = jnp.fft.fftn(field_b) + #fast fourier transform + fft_image_a = jnp.fft.fftn(field_a) + fft_image_b = jnp.fft.fftn(field_b) - #absolute value of fast fourier transform - pk = fft_image_a * jnp.conj(fft_image_b) + #absolute value of fast fourier transform + pk = fft_image_a * jnp.conj(fft_image_b) - #calculating powerspectra - real = jnp.real(pk).reshape([-1]) - imag = jnp.imag(pk).reshape([-1]) + #calculating powerspectra + real = jnp.real(pk).reshape([-1]) + imag = jnp.imag(pk).reshape([-1]) - Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j - Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) + Psum = jnp.bincount(dig, weights=(W.flatten() * imag), + length=xsum.size) * 1j + Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) - P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') + P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') - #normalization for powerspectra - norm = np.prod(np.array(shape[:])).astype('float32')**2 + #normalization for powerspectra + norm = np.prod(np.array(shape[:])).astype('float32')**2 - #find central values of each bin - kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 + #find central values of each bin + kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 - return kbins, P / norm + return kbins, P / norm def gaussian_smoothing(im, sigma):