mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-09 04:50:55 +00:00
fix formatting from main
This commit is contained in:
parent
02754cf452
commit
8da3149581
2 changed files with 57 additions and 43 deletions
19
jaxpm/pm.py
19
jaxpm/pm.py
|
@ -187,8 +187,11 @@ def pgd_correction(pos, mesh_shape, params):
|
|||
|
||||
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
|
||||
|
||||
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
||||
for i in range(3)],axis=-1)
|
||||
forces_pgd = jnp.stack([
|
||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
dpos_pgd = forces_pgd * alpha
|
||||
|
||||
|
@ -196,6 +199,7 @@ def pgd_correction(pos, mesh_shape, params):
|
|||
|
||||
|
||||
def make_neural_ode_fn(model, mesh_shape):
|
||||
|
||||
def neural_nbody_ode(state, a, cosmo, params):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
|
@ -208,15 +212,19 @@ def make_neural_ode_fn(model, mesh_shape):
|
|||
delta_k = jnp.fft.rfftn(delta)
|
||||
|
||||
# Computes gravitational potential
|
||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
|
||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
||||
r_split=0)
|
||||
|
||||
# Apply a correction filter
|
||||
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
|
||||
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
|
||||
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos)
|
||||
for i in range(3)],axis=-1)
|
||||
forces = jnp.stack([
|
||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), pos)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
forces = forces * 1.5 * cosmo.Omega_m
|
||||
|
||||
|
@ -227,4 +235,5 @@ def make_neural_ode_fn(model, mesh_shape):
|
|||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
return dpos, dvel
|
||||
|
||||
return neural_nbody_ode
|
||||
|
|
|
@ -83,7 +83,11 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
|||
return kbins, P / norm
|
||||
|
||||
|
||||
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
||||
def cross_correlation_coefficients(field_a,
|
||||
field_b,
|
||||
kmin=5,
|
||||
dk=0.5,
|
||||
boxsize=False):
|
||||
"""
|
||||
Calculate the cross correlation coefficients given two real space field
|
||||
|
||||
|
@ -118,7 +122,8 @@ def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=Fals
|
|||
real = jnp.real(pk).reshape([-1])
|
||||
imag = jnp.imag(pk).reshape([-1])
|
||||
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
|
||||
length=xsum.size) * 1j
|
||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||
|
||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||
|
|
Loading…
Add table
Reference in a new issue