diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py index c43a8f016..d817b61af 100644 --- a/t5x/checkpoint_utils.py +++ b/t5x/checkpoint_utils.py @@ -251,6 +251,12 @@ def detect_checkpoint_type( return checkpoint_type +def _is_supported_empty_value(value: Any) -> bool: + if hasattr(ocp.type_handlers, 'is_supported_empty_aggregation_type'): + return ocp.type_handlers.is_supported_empty_aggregation_type(value) + return ocp.type_handlers.is_supported_empty_value(value) + + def get_restore_parameters( directory: epath.Path, structure: PyTree, @@ -280,7 +286,7 @@ def _get_param_info( name: str, meta_or_value: Union[Any, ocp.metadata.tree.ValueMetadataEntry], ) -> Union[ocp.type_handlers.ParamInfo, Any]: - if ocp.type_handlers.is_supported_empty_aggregation_type(meta_or_value): + if _is_supported_empty_value(meta_or_value): # Empty node, ParamInfo should not be returned. return meta_or_value elif not isinstance(meta_or_value, ocp.metadata.tree.ValueMetadataEntry):