diff --git a/.all-contributorsrc b/.all-contributorsrc index f763ece..bef516f 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -13,6 +13,15 @@ "contributions": [ "ideas" ] + }, + { + "login": "dlanzieri", + "name": "Denise Lanzieri", + "avatar_url": "https://avatars.githubusercontent.com/u/72620117?v=4", + "profile": "https://github.com/dlanzieri", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, @@ -20,5 +29,7 @@ "projectOwner": "DifferentiableUniverseInitiative", "repoType": "github", "repoHost": "https://github.com", - "skipCi": true + "skipCi": true, + "commitType": "docs", + "commitConvention": "angular" } diff --git a/README.md b/README.md index 5a11d45..94941ca 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # JaxPM -[![All Contributors](https://img.shields.io/badge/all_contributors-1-orange.svg?style=flat-square)](#contributors-) [![Join the chat at https://gitter.im/DifferentiableUniverseInitiative/JaxPM](https://badges.gitter.im/DifferentiableUniverseInitiative/JaxPM.svg)](https://gitter.im/DifferentiableUniverseInitiative/JaxPM?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-) JAX-powered Cosmological Particle-Mesh N-body Solver @@ -32,9 +32,12 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d - - - + + + + + +

Francois Lanusse

🤔
Francois Lanusse
Francois Lanusse

🤔
Denise Lanzieri
Denise Lanzieri

💻
diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 79080df..5f370d8 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -172,7 +172,7 @@ def make_ode_fn(mesh_shape, halo_size=0): return nbody_ode -def pgd_correction(pos, params): +def pgd_correction(pos, mesh_shape, params): """ improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 args: @@ -180,20 +180,51 @@ def pgd_correction(pos, params): params: [alpha, kl, ks] pgd parameters """ kvec = fftk(mesh_shape) - delta = cic_paint(jnp.zeros(mesh_shape), pos) alpha, kl, ks = params delta_k = jnp.fft.rfftn(delta) - PGD_range = PGD_kernel(kvec, kl, ks) - - pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range - - forces_pgd = jnp.stack([ - cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos) - for i in range(3) - ], - axis=-1) - - dpos_pgd = forces_pgd * alpha + PGD_range=PGD_kernel(kvec, kl, ks) + + pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range + forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos) + for i in range(3)],axis=-1) + + dpos_pgd = forces_pgd*alpha + return dpos_pgd + + +def make_neural_ode_fn(model, mesh_shape): + def neural_nbody_ode(state, a, cosmo, params): + """ + state is a tuple (position, velocities) + """ + pos, vel = state + kvec = fftk(mesh_shape) + + delta = cic_paint(jnp.zeros(mesh_shape), pos) + + delta_k = jnp.fft.rfftn(delta) + + # Computes gravitational potential + pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + + # Apply a correction filter + kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) + pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) + + # Computes gravitational forces + forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos) + for i in range(3)],axis=-1) + + forces = forces * 1.5 * cosmo.Omega_m + + # Computes the update of position (drift) + dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel + + # Computes the update of velocity (kick) + dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces + + return dpos, dvel + return neural_nbody_ode diff --git a/jaxpm/utils.py b/jaxpm/utils.py index fc00a79..1593ba0 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -83,6 +83,55 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): return kbins, P / norm +def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False): + """ + Calculate the cross correlation coefficients given two real space field + + Args: + + field_a: real valued field + field_b: real valued field + kmin: minimum k-value for binned powerspectra + dk: differential in each kbin + boxsize: length of each boxlength (can be strangly shaped?) + + Returns: + + kbins: the central value of the bins for plotting + P / norm: normalized cross correlation coefficient between two field a and b + + """ + shape = field_a.shape + nx, ny, nz = shape + + #initialze values related to powerspectra (mode bins and weights) + dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) + + #fast fourier transform + fft_image_a = jnp.fft.fftn(field_a) + fft_image_b = jnp.fft.fftn(field_b) + + #absolute value of fast fourier transform + pk = fft_image_a * jnp.conj(fft_image_b) + + #calculating powerspectra + real = jnp.real(pk).reshape([-1]) + imag = jnp.imag(pk).reshape([-1]) + + Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j + Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) + + P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') + + #normalization for powerspectra + norm = np.prod(np.array(shape[:])).astype('float32')**2 + + #find central values of each bin + kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 + + return kbins, P / norm + + def gaussian_smoothing(im, sigma): """ im: 2d image