From 45b2c7f0d2932a934e2555b921880a04e34ed2d2 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 17:30:56 -0400 Subject: [PATCH] By default check_rep is false for shard_map --- jaxpm/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 05444d9..ab85856 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -19,7 +19,7 @@ def autoshmap( gpu_mesh: Mesh | None, in_specs: Specs, out_specs: Specs, - check_rep: bool = True, + check_rep: bool = False, auto: frozenset[AxisName] = frozenset()) -> Callable: """Helper function to wrap the provided function in a shard map if the code is being executed in a mesh context."""