mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
fixes for SPMD config
This commit is contained in:
parent
f3599e73da
commit
1cade569af
1 changed files with 24 additions and 11 deletions
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Callable, Iterable
|
from typing import Callable, Iterable
|
||||||
|
@ -25,7 +25,7 @@ class SPMDConfig():
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpsRegistry():
|
class OpsRegistry():
|
||||||
list_of_ops: list = []
|
list_of_ops: list = field(default_factory=list)
|
||||||
|
|
||||||
def register_operator(self, cls):
|
def register_operator(self, cls):
|
||||||
self.list_of_ops.append(cls)
|
self.list_of_ops.append(cls)
|
||||||
|
@ -36,7 +36,9 @@ class OpsRegistry():
|
||||||
|
|
||||||
if base_sharding != None:
|
if base_sharding != None:
|
||||||
for cls in self.list_of_ops:
|
for cls in self.list_of_ops:
|
||||||
|
print(f"Constctiong {cls} {cls.name}")
|
||||||
impl = construct_operator(cls, base_sharding)
|
impl = construct_operator(cls, base_sharding)
|
||||||
|
print(impl)
|
||||||
setattr(self, cls.name, impl)
|
setattr(self, cls.name, impl)
|
||||||
|
|
||||||
def restore_operators(self, base_sharding=None):
|
def restore_operators(self, base_sharding=None):
|
||||||
|
@ -117,7 +119,6 @@ def register_operator(cls):
|
||||||
pm_operators.register_operator(cls)
|
pm_operators.register_operator(cls)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def check_prolog_function(prolog_fn, impl_fn):
|
def check_prolog_function(prolog_fn, impl_fn):
|
||||||
prolog_sig = signature(prolog_fn)
|
prolog_sig = signature(prolog_fn)
|
||||||
impl_sig = signature(impl_fn)
|
impl_sig = signature(impl_fn)
|
||||||
|
@ -131,10 +132,14 @@ def check_prolog_function(prolog_fn, impl_fn):
|
||||||
|
|
||||||
if isinstance(prolog_return_annotation, tuple):
|
if isinstance(prolog_return_annotation, tuple):
|
||||||
if len(prolog_return_annotation) != len(impl_sig.parameters):
|
if len(prolog_return_annotation) != len(impl_sig.parameters):
|
||||||
raise RuntimeError("The number of outputs of the prolog does not match the number of inputs of the impl")
|
raise RuntimeError(
|
||||||
|
"The number of outputs of the prolog does not match the number of inputs of the impl"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if len(impl_sig.parameters) != 1:
|
if len(impl_sig.parameters) != 1:
|
||||||
raise RuntimeError("Prolog function output and impl function input count mismatch")
|
raise RuntimeError(
|
||||||
|
"Prolog function output and impl function input count mismatch"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -152,18 +157,24 @@ def check_epilog_function(epilog_fn, impl_fn):
|
||||||
|
|
||||||
if isinstance(impl_return_annotation, tuple):
|
if isinstance(impl_return_annotation, tuple):
|
||||||
if len(impl_return_annotation) != len(epilog_sig.parameters):
|
if len(impl_return_annotation) != len(epilog_sig.parameters):
|
||||||
raise RuntimeError("The number of outputs of the impl does not match the number of inputs of the epilog")
|
raise RuntimeError(
|
||||||
|
"The number of outputs of the impl does not match the number of inputs of the epilog"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if len(epilog_sig.parameters) != 1:
|
if len(epilog_sig.parameters) != 1:
|
||||||
raise RuntimeError("Impl function output and epilog function input count mismatch")
|
raise RuntimeError(
|
||||||
|
"Impl function output and epilog function input count mismatch"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def unpack_args(args):
|
def unpack_args(args):
|
||||||
if not isinstance(args, Iterable):
|
if not isinstance(args, Iterable):
|
||||||
args = (args, )
|
args = (args, )
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def construct_operator(cls, base_sharding=None):
|
def construct_operator(cls, base_sharding=None):
|
||||||
|
|
||||||
if base_sharding == None:
|
if base_sharding == None:
|
||||||
|
@ -173,6 +184,9 @@ def construct_operator(cls, base_sharding=None):
|
||||||
|
|
||||||
if isinstance(cls, CustomPartionedOperator):
|
if isinstance(cls, CustomPartionedOperator):
|
||||||
impl = cls.multi_gpu_impl
|
impl = cls.multi_gpu_impl
|
||||||
|
print("here")
|
||||||
|
|
||||||
|
return impl
|
||||||
|
|
||||||
elif isinstance(cls, ShardedOperator):
|
elif isinstance(cls, ShardedOperator):
|
||||||
mesh = base_sharding.mesh
|
mesh = base_sharding.mesh
|
||||||
|
@ -223,5 +237,4 @@ def construct_operator(cls, base_sharding=None):
|
||||||
|
|
||||||
elif isinstance(cls, CallBackOperator):
|
elif isinstance(cls, CallBackOperator):
|
||||||
impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding)
|
impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding)
|
||||||
|
|
||||||
return impl
|
return impl
|
||||||
|
|
Loading…
Add table
Reference in a new issue