diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 05444d9..ab85856 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -19,7 +19,7 @@ def autoshmap( gpu_mesh: Mesh | None, in_specs: Specs, out_specs: Specs, - check_rep: bool = True, + check_rep: bool = False, auto: frozenset[AxisName] = frozenset()) -> Callable: """Helper function to wrap the provided function in a shard map if the code is being executed in a mesh context."""