Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RegistryMixin improved alias management #404

Merged
merged 11 commits into from
Jan 23, 2024
115 changes: 99 additions & 16 deletions src/sparsezoo/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
"register",
"get_from_registry",
"registered_names",
"registered_aliases",
]


# maps every alias to the registered name it resolves to; a registered name is
# also stored as an alias of itself (see register_alias)
# NOTE(review): this mapping is module-wide and not keyed by parent class, and the
# `dict` default factory does not match the Dict[str, str] annotation - confirm
_ALIAS_REGISTRY: Dict[str, str] = defaultdict(dict)
# maps a parent class to its {registered name: registered value} dictionary
_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)


Expand Down Expand Up @@ -65,7 +67,7 @@ class Cifar(Dataset):
pass

# register with multiple aliases
@Dataset.register(name=["cifar-10-dataset", "cifar-100-dataset"])
@Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"])
Satrat marked this conversation as resolved.
Show resolved Hide resolved
class Cifar(Dataset):
pass

Expand All @@ -75,14 +77,23 @@ class Cifar(Dataset):
# load from custom file that implements a dataset
mnist = Dataset.load_from_registry("/path/to/mnist_dataset.py:MnistDataset")
```

Note that any registered name or alias is also recognized when all of its
hyphens are replaced with underscores, and vice versa.
For example, if a class is registered:
- as "cifar-10-dataset", it is also recognized as "cifar_10_dataset"
- as "cifar_10_dataset", it is also recognized as "cifar-10-dataset"
etc.
"""

# set to True in child class to add check that registered/retrieved values
# implement the class it is registered to
registry_requires_subclass: bool = False

@classmethod
def register(cls, name: Union[List[str], str, None] = None):
def register(
cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None
):
"""
Decorator for registering a value (ie class or function) wrapped by this
decorator to the base class (class that .register is called from)
Expand All @@ -93,28 +104,30 @@ def register(cls, name: Union[List[str], str, None] = None):
"""

def decorator(value: Any):
cls.register_value(value, name=name)
cls.register_value(value, name=name, alias=alias)
return value

return decorator

@classmethod
def register_value(cls, value: Any, name: Union[List[str], str, None] = None):
def register_value(
cls, value: Any, name: str, alias: Union[str, List[str], None] = None
):
"""
Registers the given value to the class `.register_value` is called from
:param value: value to register
:param name: name or list of names to register the wrapped value as,
:param name: name to register the wrapped value as,
defaults to value.__name__
:param alias: alias or list of aliases to register the wrapped value as,
defaults to None
"""
names = name if isinstance(name, list) else [name]

for name in names:
register(
parent_class=cls,
value=value,
name=name,
require_subclass=cls.registry_requires_subclass,
)
register(
parent_class=cls,
value=value,
name=name,
alias=alias,
require_subclass=cls.registry_requires_subclass,
)

@classmethod
def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
Expand Down Expand Up @@ -149,24 +162,36 @@ def registered_names(cls) -> List[str]:
"""
return registered_names(cls)

@classmethod
def registered_aliases(cls) -> List[str]:
    """
    Look up the aliases known for this class by delegating to the
    module-level ``registered_aliases`` helper.

    :return: list of all aliases registered to this class
    """
    aliases = registered_aliases(cls)
    return aliases


def register(
parent_class: Type,
value: Any,
name: Optional[str] = None,
alias: Union[List[str], str, None] = None,
require_subclass: bool = False,
):
"""
:param parent_class: class to register the name under
:param value: the value to register
:param name: name to register the wrapped value as, defaults to value.__name__
:param alias: alias or list of aliases to register the wrapped value as,
defaults to None
:param require_subclass: require that value is a subclass of the class this
method is called from
"""
if name is None:
# default name
name = value.__name__

register_alias(name=name, alias=alias)

if require_subclass:
_validate_subclass(parent_class, value)

Expand Down Expand Up @@ -201,12 +226,15 @@ def get_from_registry(
retrieved_value = _import_and_get_value_from_module(module_path, value_name)
else:
# look up name in registry
name = _ALIAS_REGISTRY.get(name)
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
retrieved_value = _REGISTRY[parent_class].get(name)
if retrieved_value is None:
raise KeyError(
f"Unable to find {name} registered under type {parent_class}. "
f"Unable to find {name} registered under type {parent_class}.\n"
f"Registered values for {parent_class}: "
f"{registered_names(parent_class)}"
f"{registered_names(parent_class)}\n"
f"Registered aliases for {parent_class}: "
f"{registered_aliases(parent_class)}"
)

if require_subclass:
Expand All @@ -223,6 +251,51 @@ def registered_names(parent_class: Type) -> List[str]:
return list(_REGISTRY[parent_class].keys())


def registered_aliases(parent_class: Type) -> List[str]:
    """
    :param parent_class: class to look up the registry of
    :return: all aliases registered to the given class, i.e. every alias that
        resolves to a name registered under parent_class, excluding the
        registered names themselves
    """
    names = _REGISTRY[parent_class]
    # _ALIAS_REGISTRY is shared by every parent class, so an alias is only
    # reported when the name it resolves to is registered under parent_class;
    # otherwise aliases (and names) of unrelated registries would leak in
    return [
        alias_name
        for alias_name, target_name in _ALIAS_REGISTRY.items()
        if target_name in names and alias_name not in names
    ]


def register_alias(name: str, alias: Union[str, List[str], None] = None):
    """
    Updates the mapping from the alias(es) to the given name.
    The name itself is always registered as an alias of itself, regardless
    of whether any explicit alias is given.

    Note, that the number of actual added alias(es) will be potentially
    larger than the number of given aliases, since the function adds
    variants of the alias with hyphens and underscores.

    Examples:

    If alias = ["alias1", "alias2"]
    ```
    _ALIAS_REGISTRY = {..., "alias1": "name", "alias2": "name", "name": "name"}
    ```
    If alias = None
    ```
    _ALIAS_REGISTRY = {..., "name": "name"}
    ```

    :param name: name that the alias refers to
    :param alias: single alias or list of aliases that
        refer to the name, defaults to None. A list passed in
        is not modified
    """
    if alias is None:
        aliases = []
    elif isinstance(alias, list):
        # copy, so that appending the name below never mutates the caller's list
        aliases = list(alias)
    else:
        aliases = [alias]

    # the name must always resolve to itself when looked up through the aliases
    aliases.append(name)

    for alias_name in _add_hyphen_underscores_variants(aliases):
        _ALIAS_REGISTRY[alias_name] = name
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved


def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
# import the given module path and try to get the value_name if it is included
# in the module
Expand Down Expand Up @@ -250,3 +323,13 @@ def _validate_subclass(parent_class: Type, child_class: Type):
f"class {child_class} is not a subclass of the class it is "
f"registered for: {parent_class}."
)


def _add_hyphen_underscores_variants(alias: List[str]) -> List[str]:
# add variants of the alias with hyphens and underscores
new_alias = []
for alias_name in alias:
new_alias.append(alias_name)
new_alias.append(alias_name.replace("_", "-"))
new_alias.append(alias_name.replace("-", "_"))
return list(set(new_alias))
76 changes: 53 additions & 23 deletions tests/sparsezoo/utils/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,71 @@

import pytest

from sparsezoo.utils.registry import RegistryMixin
from sparsezoo.utils.registry import _ALIAS_REGISTRY, _REGISTRY, RegistryMixin


def test_registery_flow_single():
@pytest.fixture()
def foo():
class Foo(RegistryMixin):
pass

@Foo.register()
class Foo1(Foo):
pass
yield Foo
_ALIAS_REGISTRY.clear()
_REGISTRY.clear()

assert {"Foo1"} == set(Foo.registered_names())

@Foo.register(name="name_2")
class Foo2(Foo):
pass
class TestFooRegistry:
def test_single_item(self, foo):
@foo.register()
class Foo1(foo):
pass

assert {"Foo1", "name_2"} == set(Foo.registered_names())
assert {"Foo1"} == set(foo.registered_names())
assert set() == set(foo.registered_aliases())

@Foo.register(name=["name_3", "name_4"])
class Foo3(Foo):
pass
def test_single_item_custom_name(self, foo):
@foo.register(name="name_2")
class Foo1(foo):
pass

assert {"Foo1", "name_2", "name_3", "name_4"} == set(Foo.registered_names())
assert {"name_2"} == set(foo.registered_names())
assert {"name-2"} == set(foo.registered_aliases())

with pytest.raises(KeyError):
Foo.get_value_from_registry("Foo2")
def test_alias(self, foo):
@foo.register(alias=["name-3", "name_4"])
class Foo1(foo):
pass

assert Foo.get_value_from_registry("Foo1") is Foo1
assert isinstance(Foo.load_from_registry("name_2"), Foo2)
assert (
Foo.get_value_from_registry("name_3")
is Foo3
is Foo.get_value_from_registry("name_4")
)
assert {"Foo1"} == set(foo.registered_names())
assert {"name-3", "name-4", "name_3", "name_4"} == set(foo.registered_aliases())

def test_alias_with_custom_name(self, foo):
@foo.register(name="name_2", alias=["name-3", "name_4"])
class Foo1(foo):
pass

assert {"name_2"} == set(foo.registered_names())
assert {"name-3", "name-4", "name_3", "name_4", "name-2"} == set(
foo.registered_aliases()
)

def test_get_value_from_registry(self, foo):
@foo.register(alias=["name-3"])
class Foo1(foo):
pass

@foo.register()
class Foo2(foo):
pass

with pytest.raises(KeyError):
foo.get_value_from_registry("Foo3")

assert foo.get_value_from_registry("Foo1") is Foo1
assert isinstance(foo.load_from_registry("Foo2"), Foo2)
assert foo.get_value_from_registry("Foo2") is Foo2
assert foo.get_value_from_registry("name_3") is Foo1
assert foo.get_value_from_registry("name-3") is Foo1


def test_registry_flow_multiple():
Expand Down
Loading