diff --git a/jaxpm/_src/spmd_config.py b/jaxpm/_src/spmd_config.py index 147b97e..03dd814 100644 --- a/jaxpm/_src/spmd_config.py +++ b/jaxpm/_src/spmd_config.py @@ -1,8 +1,8 @@ from abc import ABCMeta, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from inspect import signature -from typing import Callable , Iterable +from typing import Callable, Iterable from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding @@ -25,7 +25,7 @@ class SPMDConfig(): @dataclass class OpsRegistry(): - list_of_ops: list = [] + list_of_ops: list = field(default_factory=list) def register_operator(self, cls): self.list_of_ops.append(cls) @@ -36,7 +36,9 @@ class OpsRegistry(): if base_sharding != None: for cls in self.list_of_ops: + print(f"Constctiong {cls} {cls.name}") impl = construct_operator(cls, base_sharding) + print(impl) setattr(self, cls.name, impl) def restore_operators(self, base_sharding=None): @@ -117,7 +119,6 @@ def register_operator(cls): pm_operators.register_operator(cls) - def check_prolog_function(prolog_fn, impl_fn): prolog_sig = signature(prolog_fn) impl_sig = signature(impl_fn) @@ -131,10 +132,14 @@ def check_prolog_function(prolog_fn, impl_fn): if isinstance(prolog_return_annotation, tuple): if len(prolog_return_annotation) != len(impl_sig.parameters): - raise RuntimeError("The number of outputs of the prolog does not match the number of inputs of the impl") + raise RuntimeError( + "The number of outputs of the prolog does not match the number of inputs of the impl" + ) else: if len(impl_sig.parameters) != 1: - raise RuntimeError("Prolog function output and impl function input count mismatch") + raise RuntimeError( + "Prolog function output and impl function input count mismatch" + ) return True @@ -152,18 +157,24 @@ def check_epilog_function(epilog_fn, impl_fn): if isinstance(impl_return_annotation, tuple): if len(impl_return_annotation) != len(epilog_sig.parameters): - raise RuntimeError("The number of outputs of the impl does not match the number of inputs of the epilog") + raise RuntimeError( + "The number of outputs of the impl does not match the number of inputs of the epilog" + ) else: if len(epilog_sig.parameters) != 1: - raise RuntimeError("Impl function output and epilog function input count mismatch") + raise RuntimeError( + "Impl function output and epilog function input count mismatch" + ) return True + def unpack_args(args): if not isinstance(args, Iterable): - args = (args,) + args = (args, ) return args + def construct_operator(cls, base_sharding=None): if base_sharding == None: @@ -173,6 +184,9 @@ def construct_operator(cls, base_sharding=None): if isinstance(cls, CustomPartionedOperator): impl = cls.multi_gpu_impl + print("here") + + return impl elif isinstance(cls, ShardedOperator): mesh = base_sharding.mesh @@ -223,5 +237,4 @@ def construct_operator(cls, base_sharding=None): elif isinstance(cls, CallBackOperator): impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding) - - return impl + return impl