mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
fixes for SPMD config
This commit is contained in:
parent
f3599e73da
commit
1cade569af
1 changed files with 24 additions and 11 deletions
|
@ -1,8 +1,8 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Callable , Iterable
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import NamedSharding
|
||||
|
@ -25,7 +25,7 @@ class SPMDConfig():
|
|||
|
||||
@dataclass
|
||||
class OpsRegistry():
|
||||
list_of_ops: list = []
|
||||
list_of_ops: list = field(default_factory=list)
|
||||
|
||||
def register_operator(self, cls):
|
||||
self.list_of_ops.append(cls)
|
||||
|
@ -36,7 +36,9 @@ class OpsRegistry():
|
|||
|
||||
if base_sharding != None:
|
||||
for cls in self.list_of_ops:
|
||||
print(f"Constctiong {cls} {cls.name}")
|
||||
impl = construct_operator(cls, base_sharding)
|
||||
print(impl)
|
||||
setattr(self, cls.name, impl)
|
||||
|
||||
def restore_operators(self, base_sharding=None):
|
||||
|
@ -117,7 +119,6 @@ def register_operator(cls):
|
|||
pm_operators.register_operator(cls)
|
||||
|
||||
|
||||
|
||||
def check_prolog_function(prolog_fn, impl_fn):
|
||||
prolog_sig = signature(prolog_fn)
|
||||
impl_sig = signature(impl_fn)
|
||||
|
@ -131,10 +132,14 @@ def check_prolog_function(prolog_fn, impl_fn):
|
|||
|
||||
if isinstance(prolog_return_annotation, tuple):
|
||||
if len(prolog_return_annotation) != len(impl_sig.parameters):
|
||||
raise RuntimeError("The number of outputs of the prolog does not match the number of inputs of the impl")
|
||||
raise RuntimeError(
|
||||
"The number of outputs of the prolog does not match the number of inputs of the impl"
|
||||
)
|
||||
else:
|
||||
if len(impl_sig.parameters) != 1:
|
||||
raise RuntimeError("Prolog function output and impl function input count mismatch")
|
||||
raise RuntimeError(
|
||||
"Prolog function output and impl function input count mismatch"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -152,18 +157,24 @@ def check_epilog_function(epilog_fn, impl_fn):
|
|||
|
||||
if isinstance(impl_return_annotation, tuple):
|
||||
if len(impl_return_annotation) != len(epilog_sig.parameters):
|
||||
raise RuntimeError("The number of outputs of the impl does not match the number of inputs of the epilog")
|
||||
raise RuntimeError(
|
||||
"The number of outputs of the impl does not match the number of inputs of the epilog"
|
||||
)
|
||||
else:
|
||||
if len(epilog_sig.parameters) != 1:
|
||||
raise RuntimeError("Impl function output and epilog function input count mismatch")
|
||||
raise RuntimeError(
|
||||
"Impl function output and epilog function input count mismatch"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def unpack_args(args):
|
||||
if not isinstance(args, Iterable):
|
||||
args = (args,)
|
||||
args = (args, )
|
||||
return args
|
||||
|
||||
|
||||
def construct_operator(cls, base_sharding=None):
|
||||
|
||||
if base_sharding == None:
|
||||
|
@ -173,6 +184,9 @@ def construct_operator(cls, base_sharding=None):
|
|||
|
||||
if isinstance(cls, CustomPartionedOperator):
|
||||
impl = cls.multi_gpu_impl
|
||||
print("here")
|
||||
|
||||
return impl
|
||||
|
||||
elif isinstance(cls, ShardedOperator):
|
||||
mesh = base_sharding.mesh
|
||||
|
@ -223,5 +237,4 @@ def construct_operator(cls, base_sharding=None):
|
|||
|
||||
elif isinstance(cls, CallBackOperator):
|
||||
impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding)
|
||||
|
||||
return impl
|
||||
|
|
Loading…
Add table
Reference in a new issue