diff --git a/jaxpm/growth.py b/jaxpm/growth.py index 5b6908c..8194b06 100644 --- a/jaxpm/growth.py +++ b/jaxpm/growth.py @@ -587,5 +587,5 @@ def dGf2a(cosmo, a): cache = cosmo._workspace['background.growth_factor'] f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = interp(np.log(a), np.log(cache['a']), f2p) - E = E(cosmo, a) - return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f) + E_a = E(cosmo, a) + return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E_a * D2f) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index e9bfaef..b41f261 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,12 +1,11 @@ from functools import partial -import jax import jax.numpy as jnp import jax_cosmo as jc from jax.sharding import PartitionSpec as P from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d, - normal_field) + normal_field,zeros) from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, growth_rate, growth_rate_second) from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, @@ -24,23 +23,24 @@ def pm_forces(positions, """ Computes gravitational forces on particles using a PM scheme """ - print(f"pm_forces particles are {positions}") - original_shape = positions.shape if mesh_shape is None: assert (delta is not None),\ "If mesh_shape is not provided, delta should be provided" mesh_shape = delta.shape - positions = positions.reshape((*mesh_shape, 3)) if paint_particles: - paint_fn = partial(cic_paint, grid_mesh=jnp.zeros(mesh_shape)) - read_fn = partial(cic_read, positions=positions) + paint_fn = lambda x: cic_paint( + zeros(mesh_shape,sharding), x , halo_size=halo_size, sharding=sharding) + read_fn = lambda x: cic_read( + x, positions, halo_size=halo_size, sharding=sharding) else: - paint_fn = cic_paint_dx - read_fn = cic_read_dx + paint_fn = partial(cic_paint_dx, + halo_size=halo_size, + sharding=sharding) + read_fn = partial(cic_read_dx, halo_size=halo_size, sharding=sharding) if delta is None: - field = paint_fn(positions, halo_size=halo_size, sharding=sharding) + field = paint_fn(positions) delta_k = fft3d(field) elif jnp.isrealobj(delta): delta_k = fft3d(delta) @@ -54,8 +54,7 @@ def pm_forces(positions, # Computes gravitational forces forces = jnp.stack([ read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), - halo_size=halo_size, - sharding=sharding) for i in range(3)], axis=-1) # yapf: disable + ) for i in range(3)], axis=-1) # yapf: disable return forces @@ -71,19 +70,7 @@ def lpt(cosmo, Computes first and second order LPT displacement and momentum, e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) """ - print(f"particles are {particles}") - gpu_mesh = sharding.mesh if sharding is not None else None - spec = sharding.spec if sharding is not None else P() - local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable - paint_particles = True - original_shape = particles.shape if particles is not None else (*initial_conditions.shape, 3) # yapf: disable - if particles is None: - paint_particles = False - particles = autoshmap( - partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), - gpu_mesh=gpu_mesh, - in_specs=(), - out_specs=spec)() # yapf: disable + paint_particles = particles is not None a = jnp.atleast_1d(a) E = jnp.sqrt(jc.background.Esqr(cosmo, a)) @@ -107,7 +94,7 @@ def lpt(cosmo, # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) nabla_i_nabla_i = gradient_kernel(kvec, i)**2 - shear_ii = fft3d(nabla_i_nabla_i * pot_k) + shear_ii = ifft3d(nabla_i_nabla_i * pot_k) delta2 += shear_ii * shear_acc shear_acc += shear_ii @@ -117,10 +104,10 @@ def lpt(cosmo, # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel( kvec, j) - delta2 -= fft3d(nabla_i_nabla_j * pot_k)**2 + delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2 delta_k2 = fft3d(delta2) - init_force2 = pm_forces(displacement, + init_force2 = pm_forces(particles, delta=delta_k2, paint_particles=paint_particles, halo_size=halo_size, @@ -134,7 +121,7 @@ def lpt(cosmo, p += p2 f += f2 - return dx.reshape(original_shape), p, f + return dx, p, f def linear_field(mesh_shape, box_size, pk, seed, sharding=None): @@ -155,17 +142,20 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): return field -def make_ode_fn(mesh_shape, halo_size=0, sharding=None): +def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None): def nbody_ode(state, a, cosmo): """ state is a tuple (position, velocities) """ pos, vel = state + paint_particles = particles is not None - forces = pm_forces( - pos, mesh_shape=mesh_shape, halo_size=halo_size, - sharding=sharding) * 1.5 * cosmo.Omega_m + forces = pm_forces(pos, + mesh_shape=mesh_shape, + paint_particles=paint_particles, + halo_size=halo_size, + sharding=sharding) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel