Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp_proto

This commit is contained in:
Wassim KABALAN 2024-08-03 00:23:54 +02:00
commit 783a97423e
4 changed files with 112 additions and 18 deletions

View file

@ -13,6 +13,15 @@
"contributions": [
"ideas"
]
},
{
"login": "dlanzieri",
"name": "Denise Lanzieri",
"avatar_url": "https://avatars.githubusercontent.com/u/72620117?v=4",
"profile": "https://github.com/dlanzieri",
"contributions": [
"code"
]
}
],
"contributorsPerLine": 7,
@ -20,5 +29,7 @@
"projectOwner": "DifferentiableUniverseInitiative",
"repoType": "github",
"repoHost": "https://github.com",
"skipCi": true
"skipCi": true,
"commitType": "docs",
"commitConvention": "angular"
}

View file

@ -1,6 +1,6 @@
# JaxPM
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-1-orange.svg?style=flat-square)](#contributors-) [![Join the chat at https://gitter.im/DifferentiableUniverseInitiative/JaxPM](https://badges.gitter.im/DifferentiableUniverseInitiative/JaxPM.svg)](https://gitter.im/DifferentiableUniverseInitiative/JaxPM?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver
@ -32,9 +32,12 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<!-- prettier-ignore-start -->
<!-- markdownlint-disable -->
<table>
<tr>
<td align="center"><a href="http://flanusse.net"><img src="https://avatars.githubusercontent.com/u/861591?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Francois Lanusse</b></sub></a><br /><a href="#ideas-EiffL" title="Ideas, Planning, & Feedback">🤔</a></td>
</tr>
<tbody>
<tr>
<td align="center" valign="top" width="14.28%"><a href="http://flanusse.net"><img src="https://avatars.githubusercontent.com/u/861591?v=4?s=100" width="100px;" alt="Francois Lanusse"/><br /><sub><b>Francois Lanusse</b></sub></a><br /><a href="#ideas-EiffL" title="Ideas, Planning, & Feedback">🤔</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/dlanzieri"><img src="https://avatars.githubusercontent.com/u/72620117?v=4?s=100" width="100px;" alt="Denise Lanzieri"/><br /><sub><b>Denise Lanzieri</b></sub></a><br /><a href="https://github.com/DifferentiableUniverseInitiative/JaxPM/commits?author=dlanzieri" title="Code">💻</a></td>
</tr>
</tbody>
</table>
<!-- markdownlint-restore -->

View file

@ -172,7 +172,7 @@ def make_ode_fn(mesh_shape, halo_size=0):
return nbody_ode
def pgd_correction(pos, params):
def pgd_correction(pos, mesh_shape, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
args:
@ -180,20 +180,51 @@ def pgd_correction(pos, params):
params: [alpha, kl, ks] pgd parameters
"""
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range = PGD_kernel(kvec, kl, ks)
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
forces_pgd = jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
],
axis=-1)
dpos_pgd = forces_pgd * alpha
PGD_range=PGD_kernel(kvec, kl, ks)
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
for i in range(3)],axis=-1)
dpos_pgd = forces_pgd*alpha
return dpos_pgd
def make_neural_ode_fn(model, mesh_shape):
def neural_nbody_ode(state, a, cosmo, params):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
delta_k = jnp.fft.rfftn(delta)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
# Apply a correction filter
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
# Computes gravitational forces
forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos)
for i in range(3)],axis=-1)
forces = forces * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
return neural_nbody_ode

View file

@ -83,6 +83,55 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
return kbins, P / norm
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
"""
Calculate the cross correlation coefficients given two real space field
Args:
field_a: real valued field
field_b: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
Returns:
kbins: the central value of the bins for plotting
P / norm: normalized cross correlation coefficient between two field a and b
"""
shape = field_a.shape
nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform
fft_image_a = jnp.fft.fftn(field_a)
fft_image_b = jnp.fft.fftn(field_b)
#absolute value of fast fourier transform
pk = fft_image_a * jnp.conj(fft_image_b)
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
def gaussian_smoothing(im, sigma):
"""
im: 2d image