diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 8eb5f60b9d8..5c647b572db 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -141,6 +141,23 @@ def flatten_dict( return dict(items) +def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: + """Unflatten a dict with keys containing separators into a nested dict.""" + unflattened_dict: Dict[str, Any] = {} + separator: str = "." + + for key, value in flat_dict.items(): + parts = key.split(separator) + d = unflattened_dict + for part in parts[:-1]: + if part not in d: + d[part] = {} + d = d[part] + d[parts[-1]] = value + + return unflattened_dict + + def parse_config_args( config: Optional[List[str]], separator: str = ",", diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index 52dcc0f9121..0e6a5bb8cb9 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -30,6 +30,7 @@ get_project_config, get_project_dir, parse_config_args, + unflatten_dict, ) # Mock constants @@ -229,6 +230,13 @@ def test_flatten_dict() -> None: assert flatten_dict(raw_dict) == expected +def test_unflatten_dict() -> None: + """Test unflatten_dict with a flat dictionary.""" + raw_dict = {"a.b.c": "d", "e": "f"} + expected = {"a": {"b": {"c": "d"}}, "e": "f"} + assert unflatten_dict(raw_dict) == expected + + def test_parse_config_args_none() -> None: """Test parse_config_args with None as input.""" assert not parse_config_args(None)