From 5b72a734484e23773f9f69aaf13e5fbf506f8c89 Mon Sep 17 00:00:00 2001 From: denise lanzieri Date: Sat, 11 Jun 2022 14:28:30 +0200 Subject: [PATCH 1/6] neural ode added --- jaxpm/pm.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index d9870f7..d54a252 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -93,4 +93,40 @@ def pgd_correction(pos, params): dpos_pgd = forces_pgd*alpha - return dpos_pgd \ No newline at end of file + return dpos_pgd + + +def make_neural_ode_fn(model, mesh_shape): + def neural_nbody_ode(state, a, cosmo, params): + """ + state is a tuple (position, velocities) + """ + pos, vel = state + kvec = fftk(mesh_shape) + + delta = cic_paint(jnp.zeros(mesh_shape), pos) + + delta_k = jnp.fft.rfftn(delta) + + # Computes gravitational potential + pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + + # Apply a correction filter + kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) + pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) + + # Computes gravitational forces + forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos) + for i in range(3)],axis=-1) + + forces = forces * 1.5 * cosmo.Omega_m + + # Computes the update of position (drift) + dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel + + # Computes the update of velocity (kick) + dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces + + return dpos, dvel + return neural_nbody_ode + From f185b8b91396786bdf9e40b09a62a75536955f4d Mon Sep 17 00:00:00 2001 From: denise lanzieri Date: Mon, 13 Jun 2022 17:17:19 +0200 Subject: [PATCH 2/6] few adjustments to PGD correction --- jaxpm/pm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index d54a252..231a89b 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -72,7 +72,7 @@ def make_ode_fn(mesh_shape): return nbody_ode -def pgd_correction(pos, params): +def pgd_correction(pos, mesh_shape, cosmo, params): """ improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 args: From b26cfcd01a8e1a79f03f9d60ac10d282666cd110 Mon Sep 17 00:00:00 2001 From: denise lanzieri Date: Sat, 18 Jun 2022 18:23:46 +0200 Subject: [PATCH 3/6] creoss correlation function --- jaxpm/pm.py | 1 - jaxpm/utils.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 231a89b..8e9e052 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -80,7 +80,6 @@ def pgd_correction(pos, mesh_shape, cosmo, params): params: [alpha, kl, ks] pgd parameters """ kvec = fftk(mesh_shape) - delta = cic_paint(jnp.zeros(mesh_shape), pos) alpha, kl, ks = params delta_k = jnp.fft.rfftn(delta) diff --git a/jaxpm/utils.py b/jaxpm/utils.py index a01e188..0249174 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -81,6 +81,55 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): return kbins, P / norm +def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False): + """ + Calculate the cross correlation coefficients given two real space field + + Args: + + field_a: real valued field + field_b: real valued field + kmin: minimum k-value for binned powerspectra + dk: differential in each kbin + boxsize: length of each boxlength (can be strangly shaped?) + + Returns: + + kbins: the central value of the bins for plotting + P / norm: normalized cross correlation coefficient between two field a and b + + """ + shape = field_a.shape + nx, ny, nz = shape + + #initialze values related to powerspectra (mode bins and weights) + dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) + + #fast fourier transform + fft_image_a = jnp.fft.fftn(field_a) + fft_image_b = jnp.fft.fftn(field_b) + + #absolute value of fast fourier transform + pk = fft_image_a * jnp.conj(fft_image_b) + + #calculating powerspectra + real = jnp.real(pk).reshape([-1]) + imag = jnp.imag(pk).reshape([-1]) + + Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j + Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) + + P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') + + #normalization for powerspectra + norm = np.prod(np.array(shape[:])).astype('float32')**2 + + #find central values of each bin + kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 + + return kbins, P / norm + + def gaussian_smoothing(im, sigma): """ im: 2d image From 9cbdf18932b8e728842223f068057bc8c98a6d20 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 19 Jul 2024 10:49:51 -0400 Subject: [PATCH 4/6] Update jaxpm/pm.py --- jaxpm/pm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 686cc07..9b14a87 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -83,7 +83,7 @@ def make_ode_fn(mesh_shape): return nbody_ode -def pgd_correction(pos, mesh_shape, cosmo, params): +def pgd_correction(pos, mesh_shape, params): """ improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 args: From 38270f3358101d7c186e4eaf012440f594455570 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:51:25 +0000 Subject: [PATCH 5/6] docs: update README.md [skip ci] --- README.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 5a11d45..94941ca 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # JaxPM -[![All Contributors](https://img.shields.io/badge/all_contributors-1-orange.svg?style=flat-square)](#contributors-) [![Join the chat at https://gitter.im/DifferentiableUniverseInitiative/JaxPM](https://badges.gitter.im/DifferentiableUniverseInitiative/JaxPM.svg)](https://gitter.im/DifferentiableUniverseInitiative/JaxPM?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-) JAX-powered Cosmological Particle-Mesh N-body Solver @@ -32,9 +32,12 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d - - - + + + + + +

Francois Lanusse

🤔
Francois Lanusse
Francois Lanusse

🤔
Denise Lanzieri
Denise Lanzieri

💻
From b4adc119b6f84244f0f2e4fb0cb4f708f49aadb1 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:51:26 +0000 Subject: [PATCH 6/6] docs: update .all-contributorsrc [skip ci] --- .all-contributorsrc | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index f763ece..bef516f 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -13,6 +13,15 @@ "contributions": [ "ideas" ] + }, + { + "login": "dlanzieri", + "name": "Denise Lanzieri", + "avatar_url": "https://avatars.githubusercontent.com/u/72620117?v=4", + "profile": "https://github.com/dlanzieri", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, @@ -20,5 +29,7 @@ "projectOwner": "DifferentiableUniverseInitiative", "repoType": "github", "repoHost": "https://github.com", - "skipCi": true + "skipCi": true, + "commitType": "docs", + "commitConvention": "angular" }