fixes for SPMD config

This commit is contained in:
Wassim KABALAN 2024-07-09 02:35:37 +02:00
parent f3599e73da
commit 1cade569af

View file

@ -1,8 +1,8 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, field
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from typing import Callable , Iterable from typing import Callable, Iterable
from jax.experimental.shard_map import shard_map from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
@ -25,7 +25,7 @@ class SPMDConfig():
@dataclass @dataclass
class OpsRegistry(): class OpsRegistry():
list_of_ops: list = [] list_of_ops: list = field(default_factory=list)
def register_operator(self, cls): def register_operator(self, cls):
self.list_of_ops.append(cls) self.list_of_ops.append(cls)
@ -36,7 +36,9 @@ class OpsRegistry():
if base_sharding != None: if base_sharding != None:
for cls in self.list_of_ops: for cls in self.list_of_ops:
print(f"Constctiong {cls} {cls.name}")
impl = construct_operator(cls, base_sharding) impl = construct_operator(cls, base_sharding)
print(impl)
setattr(self, cls.name, impl) setattr(self, cls.name, impl)
def restore_operators(self, base_sharding=None): def restore_operators(self, base_sharding=None):
@ -117,7 +119,6 @@ def register_operator(cls):
pm_operators.register_operator(cls) pm_operators.register_operator(cls)
def check_prolog_function(prolog_fn, impl_fn): def check_prolog_function(prolog_fn, impl_fn):
prolog_sig = signature(prolog_fn) prolog_sig = signature(prolog_fn)
impl_sig = signature(impl_fn) impl_sig = signature(impl_fn)
@ -131,10 +132,14 @@ def check_prolog_function(prolog_fn, impl_fn):
if isinstance(prolog_return_annotation, tuple): if isinstance(prolog_return_annotation, tuple):
if len(prolog_return_annotation) != len(impl_sig.parameters): if len(prolog_return_annotation) != len(impl_sig.parameters):
raise RuntimeError("The number of outputs of the prolog does not match the number of inputs of the impl") raise RuntimeError(
"The number of outputs of the prolog does not match the number of inputs of the impl"
)
else: else:
if len(impl_sig.parameters) != 1: if len(impl_sig.parameters) != 1:
raise RuntimeError("Prolog function output and impl function input count mismatch") raise RuntimeError(
"Prolog function output and impl function input count mismatch"
)
return True return True
@ -152,18 +157,24 @@ def check_epilog_function(epilog_fn, impl_fn):
if isinstance(impl_return_annotation, tuple): if isinstance(impl_return_annotation, tuple):
if len(impl_return_annotation) != len(epilog_sig.parameters): if len(impl_return_annotation) != len(epilog_sig.parameters):
raise RuntimeError("The number of outputs of the impl does not match the number of inputs of the epilog") raise RuntimeError(
"The number of outputs of the impl does not match the number of inputs of the epilog"
)
else: else:
if len(epilog_sig.parameters) != 1: if len(epilog_sig.parameters) != 1:
raise RuntimeError("Impl function output and epilog function input count mismatch") raise RuntimeError(
"Impl function output and epilog function input count mismatch"
)
return True return True
def unpack_args(args): def unpack_args(args):
if not isinstance(args, Iterable): if not isinstance(args, Iterable):
args = (args,) args = (args, )
return args return args
def construct_operator(cls, base_sharding=None): def construct_operator(cls, base_sharding=None):
if base_sharding == None: if base_sharding == None:
@ -173,6 +184,9 @@ def construct_operator(cls, base_sharding=None):
if isinstance(cls, CustomPartionedOperator): if isinstance(cls, CustomPartionedOperator):
impl = cls.multi_gpu_impl impl = cls.multi_gpu_impl
print("here")
return impl
elif isinstance(cls, ShardedOperator): elif isinstance(cls, ShardedOperator):
mesh = base_sharding.mesh mesh = base_sharding.mesh
@ -223,5 +237,4 @@ def construct_operator(cls, base_sharding=None):
elif isinstance(cls, CallBackOperator): elif isinstance(cls, CallBackOperator):
impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding) impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding)
return impl return impl