mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 12:01:12 +00:00
final touches
This commit is contained in:
parent
4e4d3745f0
commit
e1daa8cba4
3 changed files with 5 additions and 4 deletions
|
@ -258,7 +258,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
|||
def test_fwd_rev_gradients(cosmo, pdims):
|
||||
|
||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
|
@ -328,7 +327,6 @@ def test_fwd_rev_gradients(cosmo, pdims):
|
|||
def test_vmap(cosmo, pdims):
|
||||
|
||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue