From 05c0ae54d72cfecc0b189b4d4b47e82547d05bb5 Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 5 Dec 2023 13:20:25 +0000 Subject: [PATCH 1/7] initial commit --- src/sparsezoo/utils/registry.py | 109 +++++++++++++++++++++---- tests/sparsezoo/utils/test_registry.py | 76 +++++++++++------ 2 files changed, 147 insertions(+), 38 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index bb37dd08..5040bfc0 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -27,9 +27,11 @@ "register", "get_from_registry", "registered_names", + "registered_aliases", ] +_ALIAS_REGISTRY: Dict[str, str] = defaultdict(dict) _REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict) @@ -82,7 +84,9 @@ class Cifar(Dataset): registry_requires_subclass: bool = False @classmethod - def register(cls, name: Union[List[str], str, None] = None): + def register( + cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None + ): """ Decorator for registering a value (ie class or function) wrapped by this decorator to the base class (class that .register is called from) @@ -93,28 +97,30 @@ def register(cls, name: Union[List[str], str, None] = None): """ def decorator(value: Any): - cls.register_value(value, name=name) + cls.register_value(value, name=name, alias=alias) return value return decorator @classmethod - def register_value(cls, value: Any, name: Union[List[str], str, None] = None): + def register_value( + cls, value: Any, name: str, alias: Union[str, List[str], None] = None + ): """ Registers the given value to the class `.register_value` is called from :param value: value to register - :param name: name or list of names to register the wrapped value as, + :param name: name to register the wrapped value as, defaults to value.__name__ + :param alias: alias or list of aliases to register the wrapped value as, + defaults to None """ - names = name if isinstance(name, list) else [name] - - for name in names: - register( - parent_class=cls, - value=value, - name=name, - require_subclass=cls.registry_requires_subclass, - ) + register( + parent_class=cls, + value=value, + name=name, + alias=alias, + require_subclass=cls.registry_requires_subclass, + ) @classmethod def load_from_registry(cls, name: str, **constructor_kwargs) -> object: @@ -149,17 +155,27 @@ def registered_names(cls) -> List[str]: """ return registered_names(cls) + @classmethod + def registered_aliases(cls) -> List[str]: + """ + :return: list of all aliases registered to this class + """ + return registered_aliases(cls) + def register( parent_class: Type, value: Any, name: Optional[str] = None, + alias: Union[List[str], str, None] = None, require_subclass: bool = False, ): """ :param parent_class: class to register the name under :param value: the value to register :param name: name to register the wrapped value as, defaults to value.__name__ + :param alias: alias or list of aliases to register the wrapped value as, + defaults to None :param require_subclass: require that value is a subclass of the class this method is called from """ @@ -167,12 +183,15 @@ def register( # default name name = value.__name__ + register_alias(name=name, alias=alias) + if require_subclass: _validate_subclass(parent_class, value) if name in _REGISTRY[parent_class]: # name already exists - raise error if two different values are attempting # to share the same name + name = _ALIAS_REGISTRY[name] registered_value = _REGISTRY[parent_class][name] if registered_value is not value: raise RuntimeError( @@ -201,12 +220,17 @@ def get_from_registry( retrieved_value = _import_and_get_value_from_module(module_path, value_name) else: # look up name in registry + + name = _ALIAS_REGISTRY.get(name) retrieved_value = _REGISTRY[parent_class].get(name) + if retrieved_value is None: raise KeyError( - f"Unable to find {name} registered under type {parent_class}. " + f"Unable to find {name} registered under type {parent_class}.\n" f"Registered values for {parent_class}: " - f"{registered_names(parent_class)}" + f"{registered_names(parent_class)}\n" + f"Registered aliases for {parent_class}: " + f"{registered_aliases(parent_class)}" ) if require_subclass: @@ -223,6 +247,51 @@ def registered_names(parent_class: Type) -> List[str]: return list(_REGISTRY[parent_class].keys()) +def registered_aliases(parent_class: Type) -> List[str]: + """ + :param parent_class: class to look up the registry of + :return: all aliases registered to the given class + """ + alias_keys = set(_ALIAS_REGISTRY.keys()) + names_keys = set(_REGISTRY[parent_class].keys()) + return list(alias_keys.difference(names_keys)) + + +def register_alias(name: str, alias: Union[str, List[str], None] = None): + """ + Updates the mapping from the alias(es) to the given name. + If the alias is None, the name is used as the alias. + + Note, that the number of actual added alias(es) will be potentially + larger than the number of given aliases, since the function adds + variants of the alias with hyphens and underscores. + + Examples: + + If alias = ["alias1", "alias2"] + ``` + _ALIAS_REGISTRY = {..., "alias1": "name", "alias2": "name"} + ``` + If alias = None + ``` + _ALIAS_REGISTRY = {..., "name": "name"} + ``` + + :param name: name that the alias refers to + :param alias: single alias or list of aliases that + refer to the name, defaults to None + """ + if alias is not None: + alias = alias if isinstance(alias, list) else [alias] + else: + alias = [] + alias.append(name) + alias = _add_hyphen_underscores_variants(alias) + + for alias_name in alias: + _ALIAS_REGISTRY[alias_name] = name + + def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any: # import the given module path and try to get the value_name if it is included # in the module @@ -250,3 +319,13 @@ def _validate_subclass(parent_class: Type, child_class: Type): f"class {child_class} is not a subclass of the class it is " f"registered for: {parent_class}." ) + + +def _add_hyphen_underscores_variants(alias: List[str]) -> List[str]: + # add variants of the alias with hyphens and underscores + new_alias = [] + for alias_name in alias: + new_alias.append(alias_name) + new_alias.append(alias_name.replace("_", "-")) + new_alias.append(alias_name.replace("-", "_")) + return list(set(new_alias)) diff --git a/tests/sparsezoo/utils/test_registry.py b/tests/sparsezoo/utils/test_registry.py index c8fcab26..73952da5 100644 --- a/tests/sparsezoo/utils/test_registry.py +++ b/tests/sparsezoo/utils/test_registry.py @@ -14,41 +14,71 @@ import pytest -from sparsezoo.utils.registry import RegistryMixin +from sparsezoo.utils.registry import _ALIAS_REGISTRY, _REGISTRY, RegistryMixin -def test_registery_flow_single(): +@pytest.fixture() +def foo(): class Foo(RegistryMixin): pass - @Foo.register() - class Foo1(Foo): - pass + yield Foo + _ALIAS_REGISTRY.clear() + _REGISTRY.clear() - assert {"Foo1"} == set(Foo.registered_names()) - @Foo.register(name="name_2") - class Foo2(Foo): - pass +class TestFooRegistry: + def test_single_item(self, foo): + @foo.register() + class Foo1(foo): + pass - assert {"Foo1", "name_2"} == set(Foo.registered_names()) + assert {"Foo1"} == set(foo.registered_names()) + assert set() == set(foo.registered_aliases()) - @Foo.register(name=["name_3", "name_4"]) - class Foo3(Foo): - pass + def test_single_item_custom_name(self, foo): + @foo.register(name="name_2") + class Foo1(foo): + pass - assert {"Foo1", "name_2", "name_3", "name_4"} == set(Foo.registered_names()) + assert {"name_2"} == set(foo.registered_names()) + assert {"name-2"} == set(foo.registered_aliases()) - with pytest.raises(KeyError): - Foo.get_value_from_registry("Foo2") + def test_alias(self, foo): + @foo.register(alias=["name-3", "name_4"]) + class Foo1(foo): + pass - assert Foo.get_value_from_registry("Foo1") is Foo1 - assert isinstance(Foo.load_from_registry("name_2"), Foo2) - assert ( - Foo.get_value_from_registry("name_3") - is Foo3 - is Foo.get_value_from_registry("name_4") - ) + assert {"Foo1"} == set(foo.registered_names()) + assert {"name-3", "name-4", "name_3", "name_4"} == set(foo.registered_aliases()) + + def test_alias_with_custom_name(self, foo): + @foo.register(name="name_2", alias=["name-3", "name_4"]) + class Foo1(foo): + pass + + assert {"name_2"} == set(foo.registered_names()) + assert {"name-3", "name-4", "name_3", "name_4", "name-2"} == set( + foo.registered_aliases() + ) + + def test_get_value_from_registry(self, foo): + @foo.register(alias=["name-3"]) + class Foo1(foo): + pass + + @foo.register() + class Foo2(foo): + pass + + with pytest.raises(KeyError): + foo.get_value_from_registry("Foo3") + + assert foo.get_value_from_registry("Foo1") is Foo1 + assert isinstance(foo.load_from_registry("Foo2"), Foo2) + assert foo.get_value_from_registry("Foo2") is Foo2 + assert foo.get_value_from_registry("name_3") is Foo1 + assert foo.get_value_from_registry("name-3") is Foo1 def test_registry_flow_multiple(): From cbe583f5167c2ca4d7a7cc385b50bd571d7b47c0 Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 5 Dec 2023 13:25:44 +0000 Subject: [PATCH 2/7] add docstrings --- src/sparsezoo/utils/registry.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index 5040bfc0..82f5a973 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -67,7 +67,7 @@ class Cifar(Dataset): pass # register with multiple aliases - @Dataset.register(name=["cifar-10-dataset", "cifar-100-dataset"]) + @Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"]) class Cifar(Dataset): pass @@ -77,6 +77,13 @@ class Cifar(Dataset): # load from custom file that implements a dataset mnist = Dataset.load_from_registry("/path/to/mnnist_dataset.py:MnistDataset") ``` + + Note, that any name or alias that is being registered, will be also recognized + when all hyphens are replaced with underscores and vice versa. + For example, if a class is registered: + - as "cifar-10-dataset", it will be also recognized as "cifar_10_dataset" + - as "cifar_10_dataset", it will be also recognized as "cifar-10-dataset" + etc. """ # set to True in child class to add check that registered/retrieved values From 3a5b7a51d69cce47aa38c69efb60d13c4bfb2743 Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 5 Dec 2023 13:32:29 +0000 Subject: [PATCH 3/7] simplify --- src/sparsezoo/utils/registry.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index 82f5a973..a3fe4543 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -198,7 +198,6 @@ def register( if name in _REGISTRY[parent_class]: # name already exists - raise error if two different values are attempting # to share the same name - name = _ALIAS_REGISTRY[name] registered_value = _REGISTRY[parent_class][name] if registered_value is not value: raise RuntimeError( @@ -227,10 +226,8 @@ def get_from_registry( retrieved_value = _import_and_get_value_from_module(module_path, value_name) else: # look up name in registry - name = _ALIAS_REGISTRY.get(name) retrieved_value = _REGISTRY[parent_class].get(name) - if retrieved_value is None: raise KeyError( f"Unable to find {name} registered under type {parent_class}.\n" From 8dad60e489c2f98b97472d5f276e237b3491fc12 Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Thu, 7 Dec 2023 11:24:58 +0000 Subject: [PATCH 4/7] hardening --- src/sparsezoo/utils/registry.py | 9 +++- tests/sparsezoo/utils/test_registry.py | 66 +++++++++++++++++++------- 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index a3fe4543..9bb34b35 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -225,8 +225,9 @@ def get_from_registry( module_path, value_name = name.split(":") retrieved_value = _import_and_get_value_from_module(module_path, value_name) else: - # look up name in registry + # look up name in alias registry name = _ALIAS_REGISTRY.get(name) + # look up name in registry retrieved_value = _REGISTRY[parent_class].get(name) if retrieved_value is None: raise KeyError( @@ -293,6 +294,12 @@ def register_alias(name: str, alias: Union[str, List[str], None] = None): alias = _add_hyphen_underscores_variants(alias) for alias_name in alias: + if alias_name in _ALIAS_REGISTRY: + raise KeyError( + f"Attempting to register alias {alias_name} as {name} " + f"however {alias_name} has already been registered as " + f"{_ALIAS_REGISTRY[alias_name]}" + ) _ALIAS_REGISTRY[alias_name] = name diff --git a/tests/sparsezoo/utils/test_registry.py b/tests/sparsezoo/utils/test_registry.py index 73952da5..d4b64b1f 100644 --- a/tests/sparsezoo/utils/test_registry.py +++ b/tests/sparsezoo/utils/test_registry.py @@ -27,7 +27,17 @@ class Foo(RegistryMixin): _REGISTRY.clear() -class TestFooRegistry: +@pytest.fixture() +def bar(): + class Bar(RegistryMixin): + pass + + yield Bar + _ALIAS_REGISTRY.clear() + _REGISTRY.clear() + + +class TestRegistryFlowSingle: def test_single_item(self, foo): @foo.register() class Foo1(foo): @@ -52,6 +62,31 @@ class Foo1(foo): assert {"Foo1"} == set(foo.registered_names()) assert {"name-3", "name-4", "name_3", "name_4"} == set(foo.registered_aliases()) + def test_key_error_on_duplicate_alias(self, foo): + @foo.register(alias=["name-3"]) + class Foo1(foo): + pass + + with pytest.raises(KeyError): + + @foo.register(alias=["name-3"]) + class Foo2(foo): + pass + + with pytest.raises(KeyError): + + @foo.register(alias=["name_3"]) + class Foo3(foo): + pass + + def test_alias_equal_name(self, foo): + @foo.register(name="name-3", alias=["name-3"]) + class Foo1(foo): + pass + + assert {"name-3"} == set(foo.registered_names()) + assert {"name_3"} == set(foo.registered_aliases()) + def test_alias_with_custom_name(self, foo): @foo.register(name="name_2", alias=["name-3", "name_4"]) class Foo1(foo): @@ -81,26 +116,21 @@ class Foo2(foo): assert foo.get_value_from_registry("name-3") is Foo1 -def test_registry_flow_multiple(): - class Foo(RegistryMixin): - pass - - class Bar(RegistryMixin): - pass - - @Foo.register() - class Foo1(Foo): - pass +class TestRegistryFlowMultiple: + def test_single_item(self, foo, bar): + @foo.register() + class Foo1(foo): + pass - @Bar.register() - class Bar1(Bar): - pass + @bar.register() + class Bar1(bar): + pass - assert ["Foo1"] == Foo.registered_names() - assert ["Bar1"] == Bar.registered_names() + assert ["Foo1"] == foo.registered_names() + assert ["Bar1"] == bar.registered_names() - assert Foo.get_value_from_registry("Foo1") is Foo1 - assert Bar.get_value_from_registry("Bar1") is Bar1 + assert foo.get_value_from_registry("Foo1") is Foo1 + assert bar.get_value_from_registry("Bar1") is Bar1 def test_registry_requires_subclass(): From eef0fb1bd8fc7552a6cfe6d7b4f6548a41c21fc6 Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Thu, 21 Dec 2023 12:20:21 +0000 Subject: [PATCH 5/7] refactor --- src/sparsezoo/utils/registry.py | 89 ++++++++++++++------------ tests/sparsezoo/utils/test_registry.py | 50 ++++++++++----- 2 files changed, 82 insertions(+), 57 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index 9bb34b35..1c41f1cf 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -28,13 +28,31 @@ "get_from_registry", "registered_names", "registered_aliases", + "standardize_lookup_name", ] -_ALIAS_REGISTRY: Dict[str, str] = defaultdict(dict) +_ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict) _REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict) +def standardize_lookup_name(name: str) -> str: + """ + Standardize the given name for lookup in the registry. + This will replace all underscores and spaces with hyphens and + convert the name to lowercase. + + example: + ``` + standardize_lookup_name("foo_bar baz") == "foo-bar-baz" + ``` + + :param name: name to standardize + :return: standardized name + """ + return name.replace("_", "-").replace(" ", "-") + + class RegistryMixin: """ Universal registry to support registration and loading of child classes and plugins @@ -66,10 +84,16 @@ class ImageNetDataset(Dataset): class Cifar(Dataset): pass + Note: the name will be standardized for lookup in the registry. + For example, if a class is registered as "cifar_dataset" or + "cifar dataset", it will be stored as "cifar-dataset". The user + will be able to load the class with any of the three name variants. + # register with multiple aliases @Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"]) class Cifar(Dataset): pass + Note: aliases will NOT be standardized for lookup in the registry. # load as "cifar-dataset" cifar = Dataset.load_from_registry("cifar-dataset") @@ -77,13 +101,6 @@ class Cifar(Dataset): # load from custom file that implements a dataset mnist = Dataset.load_from_registry("/path/to/mnnist_dataset.py:MnistDataset") ``` - - Note, that any name or alias that is being registered, will be also recognized - when all hyphens are replaced with underscores and vice versa. - For example, if a class is registered: - - as "cifar-10-dataset", it will be also recognized as "cifar_10_dataset" - - as "cifar_10_dataset", it will be also recognized as "cifar-10-dataset" - etc. """ # set to True in child class to add check that registered/retrieved values @@ -100,6 +117,8 @@ def register( :param name: name or list of names to register the wrapped value as, defaults to value.__name__ + :param alias: alias or list of aliases to register the wrapped value as, + defaults to None :return: register decorator """ @@ -190,7 +209,8 @@ def register( # default name name = value.__name__ - register_alias(name=name, alias=alias) + name = standardize_lookup_name(name) + register_alias(name=name, alias=alias, parent_class=parent_class) if require_subclass: _validate_subclass(parent_class, value) @@ -219,6 +239,7 @@ def get_from_registry( :return: value from retrieved the registry for the given name, raises error if not found """ + name = standardize_lookup_name(name) if ":" in name: # user specifying specific module to load and value to import @@ -226,7 +247,7 @@ def get_from_registry( retrieved_value = _import_and_get_value_from_module(module_path, value_name) else: # look up name in alias registry - name = _ALIAS_REGISTRY.get(name) + name = _ALIAS_REGISTRY[parent_class].get(name) # look up name in registry retrieved_value = _REGISTRY[parent_class].get(name) if retrieved_value is None: @@ -257,32 +278,23 @@ def registered_aliases(parent_class: Type) -> List[str]: :param parent_class: class to look up the registry of :return: all aliases registered to the given class """ - alias_keys = set(_ALIAS_REGISTRY.keys()) - names_keys = set(_REGISTRY[parent_class].keys()) - return list(alias_keys.difference(names_keys)) + registered_aliases_plus_names = list(_ALIAS_REGISTRY[parent_class].keys()) + registered_aliases = list( + set(registered_aliases_plus_names) - set(registered_names(parent_class)) + ) + return registered_aliases -def register_alias(name: str, alias: Union[str, List[str], None] = None): +def register_alias( + name: str, parent_class: Type, alias: Union[str, List[str], None] = None +): """ Updates the mapping from the alias(es) to the given name. If the alias is None, the name is used as the alias. - - Note, that the number of actual added alias(es) will be potentially - larger than the number of given aliases, since the function adds - variants of the alias with hyphens and underscores. - - Examples: - - If alias = ["alias1", "alias2"] - ``` - _ALIAS_REGISTRY = {..., "alias1": "name", "alias2": "name"} - ``` - If alias = None - ``` - _ALIAS_REGISTRY = {..., "name": "name"} ``` :param name: name that the alias refers to + :param parent_class: class that the name is registered under :param alias: single alias or list of aliases that refer to the name, defaults to None """ @@ -290,17 +302,22 @@ def register_alias(name: str, alias: Union[str, List[str], None] = None): alias = alias if isinstance(alias, list) else [alias] else: alias = [] + + if name in alias: + raise KeyError( + f"Attempting to register alias {name}, " + f"that is identical to the standardized name: {name}." + ) alias.append(name) - alias = _add_hyphen_underscores_variants(alias) for alias_name in alias: - if alias_name in _ALIAS_REGISTRY: + if alias_name in _ALIAS_REGISTRY[parent_class]: raise KeyError( f"Attempting to register alias {alias_name} as {name} " f"however {alias_name} has already been registered as " f"{_ALIAS_REGISTRY[alias_name]}" ) - _ALIAS_REGISTRY[alias_name] = name + _ALIAS_REGISTRY[parent_class][alias_name] = name def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any: @@ -330,13 +347,3 @@ def _validate_subclass(parent_class: Type, child_class: Type): f"class {child_class} is not a subclass of the class it is " f"registered for: {parent_class}." ) - - -def _add_hyphen_underscores_variants(alias: List[str]) -> List[str]: - # add variants of the alias with hyphens and underscores - new_alias = [] - for alias_name in alias: - new_alias.append(alias_name) - new_alias.append(alias_name.replace("_", "-")) - new_alias.append(alias_name.replace("-", "_")) - return list(set(new_alias)) diff --git a/tests/sparsezoo/utils/test_registry.py b/tests/sparsezoo/utils/test_registry.py index d4b64b1f..5e3fab29 100644 --- a/tests/sparsezoo/utils/test_registry.py +++ b/tests/sparsezoo/utils/test_registry.py @@ -39,6 +39,7 @@ class Bar(RegistryMixin): class TestRegistryFlowSingle: def test_single_item(self, foo): + # register Foo1 in registry foo under the default name Foo1 @foo.register() class Foo1(foo): pass @@ -47,22 +48,32 @@ class Foo1(foo): assert set() == set(foo.registered_aliases()) def test_single_item_custom_name(self, foo): + # register Foo1 in registry foo under the name name_2 + # (this will turn name_2 into name-2, as we standardize names) @foo.register(name="name_2") class Foo1(foo): pass - assert {"name_2"} == set(foo.registered_names()) - assert {"name-2"} == set(foo.registered_aliases()) + assert {"name-2"} == set(foo.registered_names()) + assert set() == set(foo.registered_aliases()) + # because we registered under a custom name, we can't + # use the original object name + with pytest.raises(KeyError): + foo.get_value_from_registry("Foo1") def test_alias(self, foo): + # register Foo1 in registry foo under the default name Foo1 + # and alias names name-3 and name_4 @foo.register(alias=["name-3", "name_4"]) class Foo1(foo): pass assert {"Foo1"} == set(foo.registered_names()) - assert {"name-3", "name-4", "name_3", "name_4"} == set(foo.registered_aliases()) + assert {"name-3", "name_4"} == set(foo.registered_aliases()) def test_key_error_on_duplicate_alias(self, foo): + # once we register an object under one alias, we can't + # register it under the same alias once again @foo.register(alias=["name-3"]) class Foo1(foo): pass @@ -73,29 +84,34 @@ class Foo1(foo): class Foo2(foo): pass + def test_key_error_alias_equal_name(self, foo): + # once we register the object under name-3 + # (not name_3, as we standardize names), we can't + # register it under the same alias + @foo.register(name="name_3") + class Foo1(foo): + pass + with pytest.raises(KeyError): - @foo.register(alias=["name_3"]) - class Foo3(foo): + @foo.register(alias=["name-3"]) + class Foo2(foo): pass - def test_alias_equal_name(self, foo): - @foo.register(name="name-3", alias=["name-3"]) - class Foo1(foo): - pass + def test_key_error_alias_equal_name_simultaneously(self, foo): + with pytest.raises(KeyError): - assert {"name-3"} == set(foo.registered_names()) - assert {"name_3"} == set(foo.registered_aliases()) + @foo.register(name="name_3", alias=["name-3"]) + class Foo2(foo): + pass def test_alias_with_custom_name(self, foo): @foo.register(name="name_2", alias=["name-3", "name_4"]) class Foo1(foo): pass - assert {"name_2"} == set(foo.registered_names()) - assert {"name-3", "name-4", "name_3", "name_4", "name-2"} == set( - foo.registered_aliases() - ) + assert {"name-2"} == set(foo.registered_names()) + assert {"name-3", "name_4"} == set(foo.registered_aliases()) def test_get_value_from_registry(self, foo): @foo.register(alias=["name-3"]) @@ -112,8 +128,10 @@ class Foo2(foo): assert foo.get_value_from_registry("Foo1") is Foo1 assert isinstance(foo.load_from_registry("Foo2"), Foo2) assert foo.get_value_from_registry("Foo2") is Foo2 - assert foo.get_value_from_registry("name_3") is Foo1 assert foo.get_value_from_registry("name-3") is Foo1 + assert foo.get_value_from_registry("name_3") is Foo1 + assert foo.get_value_from_registry("name 3") is Foo1 + assert isinstance(foo.load_from_registry("name_3"), Foo1) class TestRegistryFlowMultiple: From b363dea51bc4bc9abcbe5708449714ebbaa19319 Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Wed, 17 Jan 2024 12:15:31 +0000 Subject: [PATCH 6/7] format registry lookup strings to be lowercases --- src/sparsezoo/utils/registry.py | 4 ++-- tests/sparsezoo/utils/test_registry.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index 1c41f1cf..b1ed8efa 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -44,13 +44,13 @@ def standardize_lookup_name(name: str) -> str: example: ``` - standardize_lookup_name("foo_bar baz") == "foo-bar-baz" + standardize_lookup_name("Foo_bar baz") == "foo-bar-baz" ``` :param name: name to standardize :return: standardized name """ - return name.replace("_", "-").replace(" ", "-") + return name.replace("_", "-").replace(" ", "-").lower() class RegistryMixin: diff --git a/tests/sparsezoo/utils/test_registry.py b/tests/sparsezoo/utils/test_registry.py index 5e3fab29..53ba7227 100644 --- a/tests/sparsezoo/utils/test_registry.py +++ b/tests/sparsezoo/utils/test_registry.py @@ -44,7 +44,7 @@ def test_single_item(self, foo): class Foo1(foo): pass - assert {"Foo1"} == set(foo.registered_names()) + assert {"foo1"} == set(foo.registered_names()) assert set() == set(foo.registered_aliases()) def test_single_item_custom_name(self, foo): @@ -68,7 +68,7 @@ def test_alias(self, foo): class Foo1(foo): pass - assert {"Foo1"} == set(foo.registered_names()) + assert {"foo1"} == set(foo.registered_names()) assert {"name-3", "name_4"} == set(foo.registered_aliases()) def test_key_error_on_duplicate_alias(self, foo): @@ -144,11 +144,11 @@ class Foo1(foo): class Bar1(bar): pass - assert ["Foo1"] == foo.registered_names() - assert ["Bar1"] == bar.registered_names() + assert ["foo1"] == foo.registered_names() + assert ["bar1"] == bar.registered_names() - assert foo.get_value_from_registry("Foo1") is Foo1 - assert bar.get_value_from_registry("Bar1") is Bar1 + assert foo.get_value_from_registry("foo1") is Foo1 + assert bar.get_value_from_registry("bar1") is Bar1 def test_registry_requires_subclass(): From d34ad2ac9d1cb1c7c28d1761f39c1c8b9870d692 Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Thu, 18 Jan 2024 11:03:38 +0000 Subject: [PATCH 7/7] standardise aliases --- src/sparsezoo/utils/registry.py | 12 ++++++++++++ tests/sparsezoo/utils/test_registry.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index b1ed8efa..d85c83c2 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -53,6 +53,17 @@ def standardize_lookup_name(name: str) -> str: return name.replace("_", "-").replace(" ", "-").lower() +def standardize_alias_name( + name: Union[None, str, List[str]] +) -> Union[None, str, List[str]]: + if name is None: + return None + elif isinstance(name, str): + return standardize_lookup_name(name) + else: # isinstance(name, list) + return [standardize_lookup_name(n) for n in name] + + class RegistryMixin: """ Universal registry to support registration and loading of child classes and plugins @@ -210,6 +221,7 @@ def register( name = value.__name__ name = standardize_lookup_name(name) + alias = standardize_alias_name(alias) register_alias(name=name, alias=alias, parent_class=parent_class) if require_subclass: diff --git a/tests/sparsezoo/utils/test_registry.py b/tests/sparsezoo/utils/test_registry.py index 53ba7227..e22650ec 100644 --- a/tests/sparsezoo/utils/test_registry.py +++ b/tests/sparsezoo/utils/test_registry.py @@ -69,7 +69,7 @@ class Foo1(foo): pass assert {"foo1"} == set(foo.registered_names()) - assert {"name-3", "name_4"} == set(foo.registered_aliases()) + assert {"name-3", "name-4"} == set(foo.registered_aliases()) def test_key_error_on_duplicate_alias(self, foo): # once we register an object under one alias, we can't @@ -111,7 +111,7 @@ class Foo1(foo): pass assert {"name-2"} == set(foo.registered_names()) - assert {"name-3", "name_4"} == set(foo.registered_aliases()) + assert {"name-3", "name-4"} == set(foo.registered_aliases()) def test_get_value_from_registry(self, foo): @foo.register(alias=["name-3"])