diff --git a/.all-contributorsrc b/.all-contributorsrc
index f763ece..bef516f 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -13,6 +13,15 @@
"contributions": [
"ideas"
]
+ },
+ {
+ "login": "dlanzieri",
+ "name": "Denise Lanzieri",
+ "avatar_url": "https://avatars.githubusercontent.com/u/72620117?v=4",
+ "profile": "https://github.com/dlanzieri",
+ "contributions": [
+ "code"
+ ]
}
],
"contributorsPerLine": 7,
@@ -20,5 +29,7 @@
"projectOwner": "DifferentiableUniverseInitiative",
"repoType": "github",
"repoHost": "https://github.com",
- "skipCi": true
+ "skipCi": true,
+ "commitType": "docs",
+ "commitConvention": "angular"
}
diff --git a/README.md b/README.md
index 5a11d45..94941ca 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# JaxPM
-[](#contributors-) [](https://gitter.im/DifferentiableUniverseInitiative/JaxPM?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
+[](#contributors-)
JAX-powered Cosmological Particle-Mesh N-body Solver
@@ -32,9 +32,12 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
diff --git a/jaxpm/pm.py b/jaxpm/pm.py
index 79080df..5f370d8 100644
--- a/jaxpm/pm.py
+++ b/jaxpm/pm.py
@@ -172,7 +172,7 @@ def make_ode_fn(mesh_shape, halo_size=0):
return nbody_ode
-def pgd_correction(pos, params):
+def pgd_correction(pos, mesh_shape, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
args:
@@ -180,20 +180,51 @@ def pgd_correction(pos, params):
params: [alpha, kl, ks] pgd parameters
"""
kvec = fftk(mesh_shape)
-
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
- PGD_range = PGD_kernel(kvec, kl, ks)
-
- pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
-
- forces_pgd = jnp.stack([
- cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
- for i in range(3)
- ],
- axis=-1)
-
- dpos_pgd = forces_pgd * alpha
+ PGD_range=PGD_kernel(kvec, kl, ks)
+
+ pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
+ forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
+ for i in range(3)],axis=-1)
+
+ dpos_pgd = forces_pgd*alpha
+
return dpos_pgd
+
+
+def make_neural_ode_fn(model, mesh_shape):
+ def neural_nbody_ode(state, a, cosmo, params):
+ """
+ state is a tuple (position, velocities)
+ """
+ pos, vel = state
+ kvec = fftk(mesh_shape)
+
+ delta = cic_paint(jnp.zeros(mesh_shape), pos)
+
+ delta_k = jnp.fft.rfftn(delta)
+
+ # Computes gravitational potential
+ pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
+
+ # Apply a correction filter
+ kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
+ pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
+
+ # Computes gravitational forces
+ forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos)
+ for i in range(3)],axis=-1)
+
+ forces = forces * 1.5 * cosmo.Omega_m
+
+ # Computes the update of position (drift)
+ dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
+
+ # Computes the update of velocity (kick)
+ dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
+
+ return dpos, dvel
+ return neural_nbody_ode
diff --git a/jaxpm/utils.py b/jaxpm/utils.py
index fc00a79..1593ba0 100644
--- a/jaxpm/utils.py
+++ b/jaxpm/utils.py
@@ -83,6 +83,55 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
return kbins, P / norm
+def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
+ """
+ Calculate the cross correlation coefficients given two real space field
+
+ Args:
+
+ field_a: real valued field
+ field_b: real valued field
+ kmin: minimum k-value for binned powerspectra
+ dk: differential in each kbin
+ boxsize: length of each boxlength (can be strangly shaped?)
+
+ Returns:
+
+ kbins: the central value of the bins for plotting
+ P / norm: normalized cross correlation coefficient between two field a and b
+
+ """
+ shape = field_a.shape
+ nx, ny, nz = shape
+
+ #initialze values related to powerspectra (mode bins and weights)
+ dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
+
+ #fast fourier transform
+ fft_image_a = jnp.fft.fftn(field_a)
+ fft_image_b = jnp.fft.fftn(field_b)
+
+ #absolute value of fast fourier transform
+ pk = fft_image_a * jnp.conj(fft_image_b)
+
+ #calculating powerspectra
+ real = jnp.real(pk).reshape([-1])
+ imag = jnp.imag(pk).reshape([-1])
+
+ Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
+ Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
+
+ P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
+
+ #normalization for powerspectra
+ norm = np.prod(np.array(shape[:])).astype('float32')**2
+
+ #find central values of each bin
+ kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
+
+ return kbins, P / norm
+
+
def gaussian_smoothing(im, sigma):
"""
im: 2d image