fixing formatting

This commit is contained in:
Francois Lanusse 2025-06-28 23:21:31 +02:00
parent 1fcfc2aef2
commit 9d0b047f06
2 changed files with 56 additions and 51 deletions

View file

@ -3,7 +3,12 @@ from jaxpm.growth import growth_factor as Gp
from jaxpm.pm import pm_forces
def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sharding=None):
def symplectic_fpm_ode(mesh_shape,
dt0,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def drift(a, vel, args):
"""
state is a tuple (position, velocities)
@ -14,11 +19,11 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
t1 = a + dt0
# Set the scale factors
ai = t0
ac = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
ac = (t0 * t1)**0.5 # Geometric mean of t0 and t1
af = t1
#drift_contr = (Gp(cosmo, af) - Gp(cosmo, ai)) / gp(cosmo, ac)
drift_contr = (af - ai )/ ac
drift_contr = (af - ai) / ac
# Computes the update of position (drift)
dpos = 1 / (ac**3 * E(cosmo, ac)) * vel
@ -34,27 +39,23 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
t0 = a
t1 = t0 + dt0
t2 = t1 + dt0
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
t1t2 = (t1 * t2) ** 0.5 # Geometric mean of t1 and t2
t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1
t1t2 = (t1 * t2)**0.5 # Geometric mean of t1 and t2
# Set the scale factors
ac = t1
forces = (
pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
)
* 1.5
* cosmo.Omega_m
)
forces = (pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
) * 1.5 * cosmo.Omega_m)
# Computes the update of velocity (kick)
dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces
# First kick control factor
kick_factor_1 = (t1 - t0t1) / t1
kick_factor_1 = (t1 - t0t1) / t1
#kick_factor_1 = (Gf(cosmo, t1) - Gf(cosmo, t0t1)) / dGfa(cosmo, t1)
# Second kick control factor
kick_factor_2 = (t2 - t1t2) / t2
@ -71,19 +72,15 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
# Get the time steps
t0 = a
t1 = t0 + dt0
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1
forces = (
pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
)
* 1.5
* cosmo.Omega_m
)
forces = (pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
) * 1.5 * cosmo.Omega_m)
# Computes the update of velocity (kick)
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
@ -94,7 +91,12 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
return drift, kick, first_kick
def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None):
def symplectic_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def drift(a, vel, args):
"""
state is a tuple (position, velocities)
@ -113,21 +115,17 @@ def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=No
cosmo = args
forces = (
pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
)
* 1.5
* cosmo.Omega_m
)
forces = (pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
) * 1.5 * cosmo.Omega_m)
# Computes the update of velocity (kick)
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
return dvel
return drift, kick
return drift, kick