diff --git a/jaxpm/__init__.py b/jaxpm/__init__.py index e69de29..2598fa4 100644 --- a/jaxpm/__init__.py +++ b/jaxpm/__init__.py @@ -0,0 +1,8 @@ +# Execute the register_operator functions +import jaxpm._src.base_ops +import jaxpm._src.painting_ops +import jaxpm.ops +import jaxpm.painting +from jaxpm._src.spmd_config import SPMDConfig + +__all__ = ['SPMDConfig']