Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git, synced 2025-04-24 11:50:53 +00:00
Create operator infra
This commit is contained in:
parent 18d0bec2b5
commit e708f5b176
1 changed file with 227 additions and 0 deletions
227
jaxpm/_src/spmd_config.py
Normal file
@@ -0,0 +1,227 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from functools import partial
from inspect import Signature, signature
from typing import Callable, Iterable

from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding


@dataclass
class SPMDConfig():
    """Context manager that activates the multi-GPU implementations of all
    registered operators for a given sharding and restores the single-GPU
    implementations on exit."""
    sharding: NamedSharding

    def __enter__(self):
        # Swap every registered operator for its sharded implementation.
        pm_operators.construct_operators(self.sharding)
        self.sharding.mesh.__enter__()
        return self.sharding.mesh

    def __exit__(self, *exc_details):
        self.sharding.mesh.__exit__(*exc_details)
        # Fall back to the single-GPU implementations.
        pm_operators.restore_operators(self.sharding)


@dataclass
class OpsRegistry():
    # A mutable default in a dataclass requires default_factory.
    list_of_ops: list = field(default_factory=list)

    def register_operator(self, cls):
        self.list_of_ops.append(cls)
        # Register the single-GPU implementation by default
        setattr(self, cls.name, cls.single_gpu_impl)

    def construct_operators(self, base_sharding=None):
        # Build and activate the multi-GPU implementation of every
        # registered operator for the given base sharding.
        if base_sharding is not None:
            for cls in self.list_of_ops:
                impl = construct_operator(cls, base_sharding)
                setattr(self, cls.name, impl)

    def restore_operators(self, base_sharding=None):
        # Restore the single-GPU implementations.
        if base_sharding is not None:
            for cls in self.list_of_ops:
                setattr(self, cls.name, cls.single_gpu_impl)


# Global registry holding the currently active implementation of each operator.
pm_operators = OpsRegistry()


class CustomPartionedOperator(metaclass=ABCMeta):
    # Operator whose multi-GPU implementation already handles its own
    # partitioning and is used as-is.

    @staticmethod
    @abstractmethod
    def single_gpu_impl():
        return NotImplemented

    @staticmethod
    @abstractmethod
    def multi_gpu_impl():
        return NotImplemented


class CallBackOperator(metaclass=ABCMeta):
    # Operator whose multi-GPU implementation receives the base sharding
    # directly (bound via functools.partial in construct_operator).

    @staticmethod
    @abstractmethod
    def single_gpu_impl():
        return NotImplemented

    @staticmethod
    @abstractmethod
    def multi_gpu_impl():
        return NotImplemented

    @staticmethod
    @abstractmethod
    def shardings_to_use_in_impl():
        return NotImplemented


class ShardedOperator(metaclass=ABCMeta):
    # Operator whose multi-GPU implementation is wrapped in shard_map, with
    # optional prolog/epilog functions running outside the shard_map region.

    @staticmethod
    @abstractmethod
    def single_gpu_impl():
        return NotImplemented

    @staticmethod
    @abstractmethod
    def multi_gpu_prolog():
        return NotImplemented

    @staticmethod
    @abstractmethod
    def multi_gpu_epilog():
        return NotImplemented

    @staticmethod
    @abstractmethod
    def multi_gpu_impl():
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_base_sharding(base_sharding=None):
        return NotImplemented

    @staticmethod
    @abstractmethod
    def get_aux_input_from_base_sharding(base_sharding=None):
        return NotImplemented


def register_operator(cls):
    pm_operators.register_operator(cls)


def check_prolog_function(prolog_fn, impl_fn):
    prolog_sig = signature(prolog_fn)
    impl_sig = signature(impl_fn)

    # A parameterless prolog returning NotImplemented means "no prolog".
    if len(prolog_sig.parameters) == 0 and prolog_fn() is NotImplemented:
        return False

    prolog_return_annotation = prolog_sig.return_annotation
    if prolog_return_annotation is Signature.empty:
        raise RuntimeError("Prolog function must have a return annotation")

    if isinstance(prolog_return_annotation, tuple):
        if len(prolog_return_annotation) != len(impl_sig.parameters):
            raise RuntimeError(
                "The number of outputs of the prolog does not match "
                "the number of inputs of the impl")
    else:
        if len(impl_sig.parameters) != 1:
            raise RuntimeError(
                "Prolog function output and impl function input count mismatch")

    return True


def check_epilog_function(epilog_fn, impl_fn):
    epilog_sig = signature(epilog_fn)
    impl_sig = signature(impl_fn)

    # A parameterless epilog returning NotImplemented means "no epilog".
    if len(epilog_sig.parameters) == 0 and epilog_fn() is NotImplemented:
        return False

    impl_return_annotation = impl_sig.return_annotation
    if impl_return_annotation is Signature.empty:
        raise RuntimeError("Impl function must have a return annotation")

    if isinstance(impl_return_annotation, tuple):
        if len(impl_return_annotation) != len(epilog_sig.parameters):
            raise RuntimeError(
                "The number of outputs of the impl does not match "
                "the number of inputs of the epilog")
    else:
        if len(epilog_sig.parameters) != 1:
            raise RuntimeError(
                "Impl function output and epilog function input count mismatch")

    return True


def unpack_args(args):
    # Wrap a single return value in a tuple so it can be splatted uniformly.
    if not isinstance(args, Iterable):
        args = (args,)
    return args


def construct_operator(cls, base_sharding=None):

    if base_sharding is None:
        return
    elif not isinstance(base_sharding, NamedSharding):
        raise ValueError("base_sharding must be of type NamedSharding or None")

    if issubclass(cls, CustomPartionedOperator):
        # The multi-GPU implementation handles its own partitioning.
        impl = cls.multi_gpu_impl

    elif issubclass(cls, ShardedOperator):
        mesh = base_sharding.mesh
        in_spec, out_spec = cls.infer_sharding_from_base_sharding(
            base_sharding)
        __aux_input = cls.get_aux_input_from_base_sharding(base_sharding)

        if __aux_input is not None:
            multi_gpu_impl = partial(cls.multi_gpu_impl,
                                     __aux_input=__aux_input)
        else:
            multi_gpu_impl = cls.multi_gpu_impl

        multi_gpu_prolog = None
        multi_gpu_epilog = None
        if check_prolog_function(cls.multi_gpu_prolog, cls.multi_gpu_impl):
            if __aux_input is not None:
                multi_gpu_prolog = partial(cls.multi_gpu_prolog,
                                           __aux_input=__aux_input)
            else:
                multi_gpu_prolog = cls.multi_gpu_prolog
        if check_epilog_function(cls.multi_gpu_epilog, cls.multi_gpu_impl):
            if __aux_input is not None:
                multi_gpu_epilog = partial(cls.multi_gpu_epilog,
                                           __aux_input=__aux_input)
            else:
                multi_gpu_epilog = cls.multi_gpu_epilog

        # Wrap the per-shard implementation with shard_map over the base mesh.
        sharded_impl = shard_map(multi_gpu_impl,
                                 mesh=mesh,
                                 in_specs=in_spec,
                                 out_specs=out_spec,
                                 check_rep=False)

        def impl(*params, **kwargs):
            if multi_gpu_prolog is not None:
                args = multi_gpu_prolog(*params, **kwargs)
                out = sharded_impl(*unpack_args(args))
            else:
                out = sharded_impl(*params, **kwargs)

            if multi_gpu_epilog is not None:
                out = multi_gpu_epilog(*unpack_args(out))

            return out

        return impl

    elif issubclass(cls, CallBackOperator):
        # Let the operator handle sharding itself, given the base sharding.
        impl = partial(cls.multi_gpu_impl, base_sharding=base_sharding)

    return impl