fixes for SPMD config

This commit is contained in:
Wassim KABALAN 2024-07-09 02:35:37 +02:00
parent f3599e73da
commit 1cade569af

View file

@ -1,8 +1,8 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from inspect import signature
from typing import Callable , Iterable
from typing import Callable, Iterable
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
@ -25,7 +25,7 @@ class SPMDConfig():
@dataclass
class OpsRegistry():
list_of_ops: list = []
list_of_ops: list = field(default_factory=list)
def register_operator(self, cls):
self.list_of_ops.append(cls)
@ -36,7 +36,9 @@ class OpsRegistry():
if base_sharding != None:
for cls in self.list_of_ops:
print(f"Constctiong {cls} {cls.name}")
impl = construct_operator(cls, base_sharding)
print(impl)
setattr(self, cls.name, impl)
def restore_operators(self, base_sharding=None):
@ -117,7 +119,6 @@ def register_operator(cls):
pm_operators.register_operator(cls)
def check_prolog_function(prolog_fn, impl_fn):
prolog_sig = signature(prolog_fn)
impl_sig = signature(impl_fn)
@ -131,10 +132,14 @@ def check_prolog_function(prolog_fn, impl_fn):
if isinstance(prolog_return_annotation, tuple):
if len(prolog_return_annotation) != len(impl_sig.parameters):
raise RuntimeError("The number of outputs of the prolog does not match the number of inputs of the impl")
raise RuntimeError(
"The number of outputs of the prolog does not match the number of inputs of the impl"
)
else:
if len(impl_sig.parameters) != 1:
raise RuntimeError("Prolog function output and impl function input count mismatch")
raise RuntimeError(
"Prolog function output and impl function input count mismatch"
)
return True
@ -152,18 +157,24 @@ def check_epilog_function(epilog_fn, impl_fn):
if isinstance(impl_return_annotation, tuple):
if len(impl_return_annotation) != len(epilog_sig.parameters):
raise RuntimeError("The number of outputs of the impl does not match the number of inputs of the epilog")
raise RuntimeError(
"The number of outputs of the impl does not match the number of inputs of the epilog"
)
else:
if len(epilog_sig.parameters) != 1:
raise RuntimeError("Impl function output and epilog function input count mismatch")
raise RuntimeError(
"Impl function output and epilog function input count mismatch"
)
return True
def unpack_args(args):
if not isinstance(args, Iterable):
args = (args,)
args = (args, )
return args
def construct_operator(cls, base_sharding=None):
if base_sharding == None:
@ -173,6 +184,9 @@ def construct_operator(cls, base_sharding=None):
if isinstance(cls, CustomPartionedOperator):
impl = cls.multi_gpu_impl
print("here")
return impl
elif isinstance(cls, ShardedOperator):
mesh = base_sharding.mesh
@ -223,5 +237,4 @@ def construct_operator(cls, base_sharding=None):
elif isinstance(cls, CallBackOperator):
impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding)
return impl
return impl