diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 32a9691..7aac781 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -74,7 +74,7 @@ def make_ode_fn(mesh_shape): def pgd_correction(pos, params): """ - improve the short-range interactions of PM-Nbody simulations with potential gradient descent method + improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 """ kvec = fftk(mesh_shape)