mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-15 10:21:11 +00:00
update cic read halo size and notebooks examples
This commit is contained in:
parent
d4049e5db4
commit
e7112e0c25
5 changed files with 161 additions and 176 deletions
|
@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
|
|||
axis=-1)
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding=None):
|
||||
def normal_field(mesh_shape, seed, sharding=None , dtype='float32'):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
|
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
|
|||
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||
|
||||
return shard_map(
|
||||
partial(normal, shape=local_mesh_shape, dtype='float32'),
|
||||
partial(normal, shape=local_mesh_shape, dtype=dtype),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=P(None),
|
||||
out_specs=spec)(keys) # yapf: disable
|
||||
else:
|
||||
return jax.random.normal(shape=mesh_shape, key=seed)
|
||||
return jax.random.normal(shape=mesh_shape, key=seed, dtype=dtype)
|
||||
|
|
|
@ -167,7 +167,7 @@ def _cic_paint_dx_impl(displacements,
|
|||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
|
||||
if not jnp.isscalar(weight):
|
||||
if weight.shape != original_shape[:-1]:
|
||||
raise ValueError("Weight shape must match particle shape")
|
||||
|
@ -185,7 +185,7 @@ def _cic_paint_dx_impl(displacements,
|
|||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
chunk_size=2**24,
|
||||
chunk_size=chunk_size,
|
||||
val=weight)
|
||||
|
||||
|
||||
|
@ -240,6 +240,7 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
|||
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
halo_size = jax.tree.map(lambda x: x//2, halo_size)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue