diff --git a/jaxpm/growth.py b/jaxpm/growth.py index ec248f3..cb5aa82 100644 --- a/jaxpm/growth.py +++ b/jaxpm/growth.py @@ -119,7 +119,7 @@ def growth_factor(cosmo, a): if cosmo._flags["gamma_growth"]: return _growth_factor_gamma(cosmo, a) else: - return _growth_factor_ODE(cosmo, a) + return _growth_factor_ODE(cosmo, a)[0] def growth_factor_second(cosmo, a): @@ -225,7 +225,7 @@ def growth_rate_second(cosmo, a): return _growth_rate_second_ODE(cosmo, a) -def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): +def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4): """Compute linear growth factor D(a) at a given scale factor, normalised such that D(a=1) = 1. @@ -243,57 +243,56 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): Growth factor computed at requested scale factor """ # Check if growth has already been computed - if not "background.growth_factor" in cosmo._workspace.keys(): - # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) + #if not "background.growth_factor" in cosmo._workspace.keys(): + # Compute tabulated array + atab = np.logspace(log10_amin, 0.0, steps) - def D_derivs(y, x): - q = (2.0 - 0.5 * - (Omega_m_a(cosmo, x) + - (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x - r = 1.5 * Omega_m_a(cosmo, x) / x / x + def D_derivs(y, x): + q = (2.0 - 0.5 * + (Omega_m_a(cosmo, x) + + (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x + r = 1.5 * Omega_m_a(cosmo, x) / x / x - g1, g2 = y[0] - f1, f2 = y[1] - dy1da = [f1, -q * f1 + r * g1] - dy2da = [f2, -q * f2 + r * g2 - r * g1**2] - return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]]) + g1, g2 = y[0] + f1, f2 = y[1] + dy1da = [f1, -q * f1 + r * g1] + dy2da = [f2, -q * f2 + r * g2 - r * g1**2] + return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]]) - y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2], - [1.0, -6.0 / 7 * atab[0]]]) - y = odeint(D_derivs, y0, atab) + y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2], + [1.0, -6.0 / 7 * atab[0]]]) + y = odeint(D_derivs, y0, atab) - # compute second order derivatives growth - dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab) - dyda2 = np.transpose(dyda2, (2, 0, 1)) + # compute second order derivatives growth + dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab) + dyda2 = np.transpose(dyda2, (2, 0, 1)) - # Normalize results - y1 = y[:, 0, 0] - gtab = y1 / y1[-1] - y2 = y[:, 0, 1] - g2tab = y2 / y2[-1] - # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da - ftab = y[:, 1, 0] / y1[-1] * atab / gtab - f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab - # Similarly for second order derivatives - # Note: these factors are not accessible as parent functions yet - # since it is unclear what to refer to them with. - htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab - h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab + # Normalize results + y1 = y[:, 0, 0] + gtab = y1 / y1[-1] + y2 = y[:, 0, 1] + g2tab = y2 / y2[-1] + # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da + ftab = y[:, 1, 0] / y1[-1] * atab / gtab + f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab + # Similarly for second order derivatives + # Note: these factors are not accessible as parent functions yet + # since it is unclear what to refer to them with. 
+ htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab + h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab - cache = { - "a": atab, - "g": gtab, - "f": ftab, - "h": htab, - "g2": g2tab, - "f2": f2tab, - "h2": h2tab, - } - cosmo._workspace["background.growth_factor"] = cache - else: - cache = cosmo._workspace["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + + cache = { + "a": atab, + "g": gtab, + "f": ftab, + "h": htab, + "g2": g2tab, + "f2": f2tab, + "h2": h2tab, + } + + return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) , cache def _growth_rate_ODE(cosmo, a): @@ -314,12 +313,10 @@ def _growth_rate_ODE(cosmo, a): Growth rate computed at requested scale factor """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] + + cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1] return interp(a, cache["a"], cache["f"]) - def _growth_factor_second_ODE(cosmo, a): """Compute second order growth factor D2(a) at a given scale factor, normalised such that D(a=1) = 1. @@ -338,36 +335,12 @@ def _growth_factor_second_ODE(cosmo, a): Second order growth factor computed at requested scale factor """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] + #if not "background.growth_factor" in cosmo._workspace.keys(): + # _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + cache = _growth_factor_ODE(cosmo, a)[1] return interp(a, cache["a"], cache["g2"]) -def _growth_rate_ODE(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor by solving the linear - growth ODE. - - Parameters - ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor - - Returns - ------- - f: ndarray, or float if input scalar - Second order growth rate computed at requested scale factor - """ - # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] - return interp(a, cache["a"], cache["f"]) - - def _growth_rate_second_ODE(cosmo, a): """Compute second order growth rate dD2/dlna at a given scale factor by solving the linear growth ODE. @@ -386,9 +359,9 @@ def _growth_rate_second_ODE(cosmo, a): Second order growth rate computed at requested scale factor """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] + #if not "background.growth_factor" in cosmo._workspace.keys(): + # _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + cache = _growth_factor_ODE(cosmo, a)[1] return interp(a, cache["a"], cache["f2"]) @@ -521,6 +494,34 @@ def Gf2(cosmo, a): return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) +def gp(cosmo, a): + r""" Derivative of D1 against a + + Parameters + ---------- + cosmo: dict + Cosmology dictionary. + + a : array_like + Scale factor. + + Returns + ------- + Scalar float Tensor : the derivative of D1 against a. + + Notes + ----- + + The expression for :math:`gp(a)` is: + + .. 
math:: + gp(a)=\frac{dD1}{da}= D'_{1norm}/a + """ + f1 = growth_rate(cosmo, a) + g1 = growth_factor(cosmo, a) + D1f = f1 * g1 / a + return D1f + def dGfa(cosmo, a): r""" Derivative of Gf against a @@ -549,7 +550,8 @@ def dGfa(cosmo, a): f1 = growth_rate(cosmo, a) g1 = growth_factor(cosmo, a) D1f = f1 * g1 / a - cache = cosmo._workspace['background.growth_factor'] + #cache = cosmo._workspace['background.growth_factor'] + cache = _growth_factor_ODE(cosmo, a)[1] f1p = cache['h'] / cache['a'] * cache['g'] f1p = interp(np.log(a), np.log(cache['a']), f1p) Ea = E(cosmo, a) @@ -584,9 +586,10 @@ def dGf2a(cosmo, a): f2 = growth_rate_second(cosmo, a) g2 = growth_factor_second(cosmo, a) D2f = f2 * g2 / a - cache = cosmo._workspace['background.growth_factor'] + #cache = cosmo._workspace['background.growth_factor'] + cache = _growth_factor_ODE(cosmo, a)[1] f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = interp(np.log(a), np.log(cache['a']), f2p) E_a = E(cosmo, a) return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + - 3 * a**2 * E_a * D2f) + 3 * a**2 * E_a * D2f) \ No newline at end of file diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 3083f08..f3c50df 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter -def _cic_paint_impl(grid_mesh, positions, weight=None): +def _cic_paint_impl(grid_mesh, positions, weight=1.): """ Paints positions onto mesh mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] @@ -27,12 +27,10 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): neighboor_coords = floor + connection kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - if weight is not None: - if jnp.isscalar(weight): - kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) - else: - kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), - kernel) + if jnp.isscalar(weight): + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + else: + kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel) neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 8, 3]).astype('int32'), @@ -48,7 +46,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): @partial(jax.jit, static_argnums=(3, 4)) -def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): +def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None): + + if sharding is not None: + print(""" + WARNING : absolute painting is not recommended in multi-device mode. + Please use relative painting instead. 
+ """) positions = positions.reshape((*grid_mesh.shape, 3)) @@ -57,9 +61,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + weight_spec = P() if jnp.isscalar(weight) else spec + grid_mesh = autoshmap(_cic_paint_impl, gpu_mesh=gpu_mesh, - in_specs=(spec, spec, P()), + in_specs=(spec, spec, weight_spec), out_specs=spec)(grid_mesh, positions, weight) grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents, @@ -128,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight): positions: [npart, 2] weight: [npart] """ + positions = positions.reshape([-1, 2]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -136,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight): kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] if weight is not None: - kernel = kernel * weight[..., jnp.newaxis] + kernel = kernel * weight.reshape(*positions.shape[:-1]) neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 4, 2]).astype('int32'), @@ -151,7 +158,10 @@ def cic_paint_2d(mesh, positions, weight): return mesh -def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): +def _cic_paint_dx_impl(displacements, + weight=1., + halo_size=0, + chunk_size=2**24): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] @@ -190,13 +200,13 @@ def cic_paint_dx(displacements, gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + weight_spec = P() if jnp.isscalar(weight) else spec grid_mesh = autoshmap(partial(_cic_paint_dx_impl, halo_size=halo_size, - weight=weight, chunk_size=chunk_size), gpu_mesh=gpu_mesh, - in_specs=spec, - out_specs=spec)(displacements) + in_specs=(spec, weight_spec), + out_specs=spec)(displacements, weight) grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents, diff --git a/jaxpm/pm.py b/jaxpm/pm.py index e34d584..9951e1c 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -172,8 +172,7 @@ def make_ode_fn(mesh_shape, return nbody_ode -def make_diffrax_ode(cosmo, - mesh_shape, +def make_diffrax_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None): @@ -183,6 +182,7 @@ def make_diffrax_ode(cosmo, state is a tuple (position, velocities) """ pos, vel = state + cosmo = args forces = pm_forces(pos, mesh_shape=mesh_shape, diff --git a/notebooks/02-Advanced_usage.ipynb b/notebooks/02-Advanced_usage.ipynb index cf7f611..9027ef2 100644 --- a/notebooks/02-Advanced_usage.ipynb +++ b/notebooks/02-Advanced_usage.ipynb @@ -32,7 +32,7 @@ "from jaxpm.painting import cic_paint , cic_paint_dx\n", "from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n", "from jaxpm.distributed import uniform_particles\n", - "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve" + "from diffrax import PIDController, Tsit5, ODETerm, SaveAt, diffeqsolve" ] }, { @@ -41,10 +41,9 @@ "source": [ "### Particle Mesh Simulation with Diffrax Leapfrog Solver\n", "\n", - "In this setup, we use the `LeapfrogMidpoint` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation. 
The novelty here is the use of a **Leapfrog solver** from `diffrax` for efficient, memory-saving time integration.\n", + "In this setup, we use the `Tsit5` solver from the `diffrax` library to evolve particle displacements over time in our Particle Mesh simulation.\n", "\n", - "- **Leapfrog Integration**: This symplectic integrator is well-suited for simulations of gravitational dynamics, preserving energy over long timescales and allowing larger time steps without sacrificing accuracy.\n", - "- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which, combined with Leapfrog’s stability, enhances memory efficiency and speeds up computation.\n" + "- **Efficient Displacement Tracking**: We initialize only displacements (`dx`) rather than absolute positions, which uses a the `pmwd` cic_painting algorithm which is more memory efficient at the cost of being slightly slower\n" ] }, { @@ -84,10 +83,10 @@ " \n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", @@ -257,10 +256,10 @@ " \n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", diff --git a/notebooks/03-MultiGPU_PM_Halo.ipynb b/notebooks/03-MultiGPU_PM_Halo.ipynb index 0a652d2..be0bfb8 100644 --- a/notebooks/03-MultiGPU_PM_Halo.ipynb +++ b/notebooks/03-MultiGPU_PM_Halo.ipynb @@ -90,7 +90,7 @@ "\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "\n", - "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. 
`create_device_mesh(pdims)` initializes this layout across available GPUs.\n", + "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "\n", @@ -99,21 +99,18 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "9edd2246", "metadata": {}, "outputs": [], "source": [ - "from jax.experimental.mesh_utils import create_device_mesh\n", "from jax.experimental.multihost_utils import process_allgather\n", - "from jax.sharding import Mesh, NamedSharding\n", - "from jax.sharding import PartitionSpec as P\n", + "from jax.sharding import PartitionSpec as P, NamedSharding\n", "\n", - "all_gather = partial(process_allgather, tiled=False)\n", + "all_gather = partial(process_allgather, tiled=True)\n", "\n", "pdims = (2, 4)\n", - "devices = create_device_mesh(pdims)\n", - "mesh = Mesh(devices, axis_names=('x', 'y'))\n", + "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n", "sharding = NamedSharding(mesh, P('x', 'y'))" ] }, @@ -180,10 +177,10 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", @@ -410,10 +407,10 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", - " solver = LeapfrogMidpoint()\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False , sharding=sharding , halo_size=halo_size))\n", + " solver = Tsit5()\n", "\n", - " stepsize_controller = ConstantStepSize()\n", + " stepsize_controller = PIDController(rtol=1e-6 , atol=1e-6)\n", " res = diffeqsolve(ode_fn,\n", " solver,\n", " t0=0.1,\n", @@ -689,7 +686,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/notebooks/04-MultiGPU_PM_Solvers.ipynb b/notebooks/04-MultiGPU_PM_Solvers.ipynb index 7671bc7..1c22f33 100644 --- a/notebooks/04-MultiGPU_PM_Solvers.ipynb +++ b/notebooks/04-MultiGPU_PM_Solvers.ipynb @@ -62,7 +62,7 @@ "\n", "This cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute data efficiently.\n", "\n", - "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid. 
`create_device_mesh(pdims)` initializes this layout across available GPUs.\n", + "- **Device Mesh**: `pdims = (2, 4)` arranges devices in a 2x4 grid.\n", "- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` assigns the mesh grid axes, which allows flexible mapping of array data across devices.\n", "- **PartitionSpec and NamedSharding**: `PartitionSpec` defines data partitioning across mesh axes `('x', 'y')`, and `NamedSharding(mesh, P('x', 'y'))` specifies this sharding scheme for arrays in the simulation.\n", "\n", @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -80,11 +80,10 @@ "from jax.sharding import Mesh, NamedSharding\n", "from jax.sharding import PartitionSpec as P\n", "\n", - "all_gather = partial(process_allgather, tiled=False)\n", + "all_gather = partial(process_allgather, tiled=True)\n", "\n", "pdims = (2, 4)\n", - "devices = create_device_mesh(pdims)\n", - "mesh = Mesh(devices, axis_names=('x', 'y'))\n", + "mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))\n", "sharding = NamedSharding(mesh, P('x', 'y'))" ] }, @@ -124,7 +123,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n", " solver = LeapfrogMidpoint()\n", "\n", " stepsize_controller = ConstantStepSize()\n", @@ -288,7 +287,7 @@ "\n", " # Evolve the simulation forward\n", " ode_fn = ODETerm(\n", - " make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n", + " make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))\n", " solver = Dopri5()\n", "\n", " stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)\n", diff --git a/notebooks/05-MultiHost_PM.py b/notebooks/05-MultiHost_PM.py index da3964e..c41d1cf 100644 --- a/notebooks/05-MultiHost_PM.py +++ b/notebooks/05-MultiHost_PM.py @@ -17,9 +17,8 @@ import jax_cosmo as jc import numpy as np from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt, diffeqsolve) -from jax.experimental.mesh_utils import create_device_mesh from jax.experimental.multihost_utils import process_allgather -from jax.sharding import Mesh, NamedSharding +from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from jaxpm.kernels import interpolate_power_spectrum @@ -78,7 +77,7 @@ def parse_arguments(): def create_mesh_and_sharding(pdims): devices = create_device_mesh(pdims) - mesh = Mesh(devices, axis_names=('x', 'y')) + mesh = jax.make_mesh(pdims, axis_names=('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y')) return mesh, sharding @@ -106,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(mesh_shape, + paint_absolute_pos=False, + sharding=sharding, + halo_size=halo_size)) # Choose solver solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5() diff --git a/notebooks/README.md b/notebooks/README.md index 43d9d0b..872fdd4 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -37,3 +37,50 @@ Each notebook includes installation instructions and guidelines for configuring - **SLURM** for job scheduling on clusters (if running multi-host setups) > **Note**: These notebooks are tested on the **Jean Zay** supercomputer and 
may require configuration changes for different HPC clusters.
+
+## Caveats
+
+### Cloud-in-Cell (CIC) Painting (Single Device)
+
+There are two ways to perform CIC painting in JAXPM. The first is `cic_paint`, which paints absolute particle positions onto the mesh. The second is `cic_paint_dx`, which paints relative particle positions (displacements of a uniform particle grid) onto the mesh. The absolute version is faster, at the cost of higher memory usage.
+
+In order to use relative painting you need to:
+
+ - Set the `particles` argument of the `lpt` function from `jaxpm.pm` to `None`
+ - Set `paint_absolute_pos` to `False` in the `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is `True` by default)
+
+Otherwise, set `particles` to the starting particles of your choice and leave `paint_absolute_pos` as `True` (the default). A minimal sketch of a relative-painting run is given at the end of this section.
+
+### Cloud-in-Cell (CIC) Painting (Multi Device)
+
+Both `cic_paint` and `cic_paint_dx` are available in multi-device mode.
+
+You need to set the `sharding` and `halo_size` arguments, as explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
+
+Note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and is therefore not recommended.
+
+Using relative painting in multi-device mode works just like in single-device mode:\
+set the `particles` argument of the `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`.
+
+### Distributed PM
+
+To run a distributed PM simulation, follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and, for multi-host runs, [05](05-MultiHost_PM.ipynb). A minimal multi-device sketch is also given at the end of this section.
+
+In short, you need to pass the `sharding` and `halo_size` arguments to `lpt`, `linear_field`, the `make_ode` functions, and `pm_forces` if you use it.
+
+Mismatching the shardings will give you errors and unexpected results.
+
+You can also use `normal_field` and `uniform_particles` from `jaxpm.distributed` to create fields and particles with a given sharding.
+
+### Choosing the right pdims
+
+`pdims` are the processor dimensions.\
+They are explained in more detail in the jaxDecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
+
+For 8 devices, three decompositions are possible:
+- (1, 8)
+- (2, 4) and (4, 2)
+- (8, 1)
+
+(1, X) should be the fastest; (2, X) or (X, 2) is more accurate but slightly slower,\
+and (X, 1) gives the least accurate results for reasons that are not yet understood, so it is not recommended.
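+### Example: relative painting on a single device
+
+The sketch below illustrates the relative-painting workflow described above (no `particles` passed to `lpt`, `paint_absolute_pos=False`, and the cosmology handed to the solver through `args`, per the `make_diffrax_ode` signature introduced in this diff). The random-normal initial conditions, the `Planck15` parameters, the mesh size, and the solver tolerances are placeholder choices, not values prescribed by JAXPM; in practice you would generate `initial_conditions` with `linear_field` from `jaxpm.pm`.
+
+```python
+import jax
+import jax.numpy as jnp
+import jax_cosmo as jc
+from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
+
+from jaxpm.painting import cic_paint_dx
+from jaxpm.pm import lpt, make_diffrax_ode
+
+mesh_shape = (64, 64, 64)
+cosmo = jc.Planck15(Omega_c=0.25, sigma8=0.8)
+
+# Placeholder linear density field; use `linear_field` for a physical spectrum.
+initial_conditions = jax.random.normal(jax.random.PRNGKey(0), mesh_shape)
+
+# Relative painting: no `particles` argument, so `lpt` returns displacements `dx`.
+dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=1)
+
+# The cosmology is no longer closed over by the ODE term; it is passed via `args`.
+ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
+res = diffeqsolve(ode_fn,
+                  Dopri5(),
+                  t0=0.1,
+                  t1=1.0,
+                  dt0=None,
+                  y0=jnp.stack([dx, p]),
+                  args=cosmo,
+                  stepsize_controller=PIDController(rtol=1e-5, atol=1e-5),
+                  saveat=SaveAt(t1=True))
+
+# Paint the final displacements onto the mesh.
+final_field = cic_paint_dx(res.ys[-1, 0])
+```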
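+### Example: distributed PM with sharding
+
+The following sketch shows how the same relative-painting run looks on multiple devices, with `sharding` and `halo_size` passed consistently to `lpt`, `make_diffrax_ode`, and `cic_paint_dx`, as described in the Distributed PM section. The `pdims`, mesh shape, cosmology, and tolerances are illustrative assumptions, and the halo size simply follows the choice used in the tests of this diff; only the argument plumbing reflects the repository API.
+
+```python
+import jax
+import jax.numpy as jnp
+import jax_cosmo as jc
+from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
+from jax import lax
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+
+from jaxpm.painting import cic_paint_dx
+from jaxpm.pm import lpt, make_diffrax_ode
+
+pdims = (2, 4)  # 8 devices arranged on a 2x4 processor grid
+mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
+sharding = NamedSharding(mesh, P('x', 'y'))
+
+mesh_shape = (64, 64, 64)
+halo_size = mesh_shape[0] // 2  # as in the tests; tune for your configuration
+cosmo = jc.Planck15(Omega_c=0.25, sigma8=0.8)
+
+# Placeholder initial conditions; in practice generate them with `linear_field`
+# and the same `sharding`.
+initial_conditions = jax.random.normal(jax.random.PRNGKey(0), mesh_shape)
+initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding)
+
+# The same sharding and halo size are passed to every distributed step.
+dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=1,
+               halo_size=halo_size, sharding=sharding)
+ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
+                                  paint_absolute_pos=False,
+                                  halo_size=halo_size,
+                                  sharding=sharding))
+res = diffeqsolve(ode_fn,
+                  Dopri5(),
+                  t0=0.1,
+                  t1=1.0,
+                  dt0=None,
+                  y0=jnp.stack([dx, p]),
+                  args=cosmo,
+                  stepsize_controller=PIDController(rtol=1e-5, atol=1e-5),
+                  saveat=SaveAt(t1=True))
+
+final_field = cic_paint_dx(res.ys[-1, 0], halo_size=halo_size, sharding=sharding)
+```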
diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py index 6d17939..5ebcbc2 100644 --- a/tests/test_against_fpm.py +++ b/tests/test_against_fpm.py @@ -76,7 +76,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -95,6 +95,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -121,8 +122,7 @@ def test_nbody_relative(simulation_config, initial_conditions, # Initial displacement dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) - ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) solver = Dopri5() controller = PIDController(rtol=1e-9, @@ -141,6 +141,7 @@ def test_nbody_relative(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index fd683ab..69c37ed 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -2,8 +2,11 @@ from conftest import initialize_distributed initialize_distributed() # ignore : E402 +from functools import partial # noqa : E402 + import jax # noqa : E402 import jax.numpy as jnp # noqa : E402 +import jax_cosmo as jc # noqa : E402 import pytest # noqa : E402 from diffrax import SaveAt # noqa : E402 from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve @@ -12,19 +15,31 @@ from jax import lax # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P # noqa : E402 +from jaxdecomp import get_fft_output_sharding from jaxpm.distributed import uniform_particles # noqa : E402 +from jaxpm.distributed import fft3d, ifft3d from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 -from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 +from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402 -_TOLERANCE = 3.0 # 🙃🙃 +_TOLERANCE = 1e-1 # 🙃🙃 + +pdims = [(1, 8), (8, 1), (4, 2), (2, 4)] @pytest.mark.distributed @pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("pdims", pdims) @pytest.mark.parametrize("absolute_painting", [True, False]) def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, - absolute_painting): + pdims, absolute_painting): + + if absolute_painting: + pytest.skip("Absolute painting is not recommended in distributed mode") + + painting_str = "absolute" if absolute_painting else "relative" + print("=" * 50) + print(f"Running with {painting_str} painting and pdims {pdims} ...") mesh_shape, box_shape = simulation_config # SINGLE DEVICE RUN @@ -37,12 +52,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, particles, a=0.1, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) - ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape, + paint_absolute_pos=False)) y0 = jnp.stack([dx, p]) solver = Dopri5() @@ -60,6 
+75,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -72,7 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, print("Done with single device run") # MULTI DEVICE RUN - mesh = jax.make_mesh((1, 8), ('x', 'y')) + mesh = jax.make_mesh(pdims, ('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y')) halo_size = mesh_shape[0] // 2 @@ -94,8 +110,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, - mesh_shape, + make_diffrax_ode(mesh_shape, halo_size=halo_size, sharding=sharding)) @@ -108,8 +123,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, halo_size=halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, - mesh_shape, + make_diffrax_ode(mesh_shape, paint_absolute_pos=False, halo_size=halo_size, sharding=sharding)) @@ -130,16 +144,23 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) + final_field = solutions.ys[-1, 0] + print(f"Final field sharding is {final_field.sharding}") + + assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \ + , f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}" + if absolute_painting: multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), - solutions.ys[-1, 0], + final_field, halo_size=halo_size, sharding=sharding) else: - multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], + multi_device_final_field = cic_paint_dx(final_field, halo_size=halo_size, sharding=sharding) @@ -150,3 +171,230 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, print(f"MSE is {mse}") assert mse < _TOLERANCE + + +@pytest.mark.distributed +@pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("pdims", pdims) +def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, + order, nbody_from_lpt1, nbody_from_lpt2, pdims): + + mesh_shape, box_shape = simulation_config + # SINGLE DEVICE RUN + cosmo._workspace = {} + + mesh = jax.make_mesh(pdims, ('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + halo_size = mesh_shape[0] // 2 + + initial_conditions = lax.with_sharding_constraint(initial_conditions, + sharding) + + print(f"sharded initial conditions {initial_conditions.sharding}") + cosmo._workspace = {} + + @jax.jit + def forward_model(initial_conditions, cosmo): + + dx, p, _ = lpt(cosmo, + initial_conditions, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) + ode_fn = ODETerm( + make_diffrax_ode(mesh_shape, + paint_absolute_pos=False, + halo_size=halo_size, + sharding=sharding)) + y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + solutions = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1.0, + dt0=None, + y0=y0, + args=cosmo, + stepsize_controller=controller, + saveat=saveat) + + multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) + + return multi_device_final_field + + @jax.jit + def model(initial_conditions, cosmo): + final_field = forward_model(initial_conditions, cosmo) + return MSE(final_field, + nbody_from_lpt1 if order == 1 
else nbody_from_lpt2) + + obs_val = model(initial_conditions, cosmo) + + shifted_initial_conditions = initial_conditions + jax.random.normal( + jax.random.key(42), initial_conditions.shape) * 5 + + good_grads = jax.grad(model)(initial_conditions, cosmo) + off_grads = jax.grad(model)(shifted_initial_conditions, cosmo) + + assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding, + ndim=3) + assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding, + ndim=3) + + +@pytest.mark.distributed +@pytest.mark.parametrize("pdims", pdims) +def test_fwd_rev_gradients(cosmo, pdims): + + mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0) + cosmo._workspace = {} + + mesh = jax.make_mesh(pdims, ('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + halo_size = mesh_shape[0] // 2 + + initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape) + initial_conditions = lax.with_sharding_constraint(initial_conditions, + sharding) + print(f"sharded initial conditions {initial_conditions.sharding}") + cosmo._workspace = {} + + @partial(jax.jit, static_argnums=(2, 3, 4)) + def compute_forces(initial_conditions, + cosmo, + a=0.5, + halo_size=0, + sharding=None): + + paint_absolute_pos = False + particles = jnp.zeros_like(initial_conditions, + shape=(*initial_conditions.shape, 3)) + + a = jnp.atleast_1d(a) + E = jnp.sqrt(jc.background.Esqr(cosmo, a)) + + initial_conditions = jax.lax.with_sharding_constraint( + initial_conditions, sharding) + delta_k = fft3d(initial_conditions) + out_sharding = get_fft_output_sharding(sharding) + delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding) + + initial_force = pm_forces(particles, + delta=delta_k, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) + + return initial_force[..., 0] + + forces = compute_forces(initial_conditions, + cosmo, + halo_size=halo_size, + sharding=sharding) + back_gradient = jax.jacrev(compute_forces)(initial_conditions, + cosmo, + halo_size=halo_size, + sharding=sharding) + fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions, + cosmo, + halo_size=halo_size, + sharding=sharding) + + print(f"Forces sharding is {forces.sharding}") + print(f"Backward gradient sharding is {back_gradient.sharding}") + print(f"Forward gradient sharding is {fwd_gradient.sharding}") + assert forces.sharding.is_equivalent_to(initial_conditions.sharding, + ndim=3) + assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to( + initial_conditions.sharding, ndim=3) + assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding, + ndim=3) + + +@pytest.mark.distributed +@pytest.mark.parametrize("pdims", pdims) +def test_vmap(cosmo, pdims): + + mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0) + cosmo._workspace = {} + + mesh = jax.make_mesh(pdims, ('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + halo_size = mesh_shape[0] // 2 + + single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42), + mesh_shape) + initial_conditions = lax.with_sharding_constraint( + single_dev_initial_conditions, sharding) + + single_ics = jnp.stack([ + single_dev_initial_conditions, single_dev_initial_conditions, + single_dev_initial_conditions + ]) + sharded_ics = jnp.stack( + [initial_conditions, initial_conditions, initial_conditions]) + print(f"unsharded initial conditions batch {single_ics.sharding}") + print(f"sharded initial conditions batch {sharded_ics.sharding}") + cosmo._workspace = {} + + @partial(jax.jit, static_argnums=(2, 3, 4)) + def 
compute_forces(initial_conditions, + cosmo, + a=0.5, + halo_size=0, + sharding=None): + + paint_absolute_pos = False + particles = jnp.zeros_like(initial_conditions, + shape=(*initial_conditions.shape, 3)) + + a = jnp.atleast_1d(a) + E = jnp.sqrt(jc.background.Esqr(cosmo, a)) + + initial_conditions = jax.lax.with_sharding_constraint( + initial_conditions, sharding) + delta_k = fft3d(initial_conditions) + out_sharding = get_fft_output_sharding(sharding) + delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding) + + initial_force = pm_forces(particles, + delta=delta_k, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) + + return initial_force[..., 0] + + def fn(ic): + return compute_forces(ic, + cosmo, + halo_size=halo_size, + sharding=sharding) + + v_compute_forces = jax.vmap(fn) + + print(f"single_ics shape {single_ics.shape}") + print(f"sharded_ics shape {sharded_ics.shape}") + + single_dev_forces = v_compute_forces(single_ics) + sharded_forces = v_compute_forces(sharded_ics) + + assert single_dev_forces.ndim == 4 + assert sharded_forces.ndim == 4 + + print(f"Sharded forces {sharded_forces.sharding}") + + assert sharded_forces[0].sharding.is_equivalent_to( + initial_conditions.sharding, ndim=3) + assert sharded_forces.sharding.spec[0] == None diff --git a/tests/test_gradients.py b/tests/test_gradients.py index bb48920..1f611aa 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, particles, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: @@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, a=lpt_scale_factor, order=order) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) y0 = jnp.stack([dx, p]) solver = Dopri5() @@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, t1=1.0, dt0=None, y0=y0, + args=cosmo, adjoint=adjoint, stepsize_controller=controller, saveat=saveat)