From 97cf12701fc45785818f0835032766d05ef3e52e Mon Sep 17 00:00:00 2001
From: jernejfrank
Date: Fri, 11 Oct 2024 21:04:22 +0100
Subject: [PATCH] Restrict pipe_input API to single on_input target

The implementation allows a single global string as the on_input target.

- Functionality for multiple global and local parameter targets is in place,
  but we short-circuit for now with an error if the argument is not a string.
- At the moment all transforms in the pipe are applied to the same argument.
- Tests and docs keep the multiple-parameter cases commented out so they are
  easy to re-enable should we want this.
- Added the URL of the GitHub issue tracking interest in enabling multiple
  parameters.
---
 hamilton/function_modifiers/macros.py   | 226 ++++++-----
 tests/function_modifiers/test_macros.py | 517 ++++++++++++------------
 tests/resources/pipe_input.py           |  99 ++---
 3 files changed, 446 insertions(+), 396 deletions(-)

diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py
index 21f6242ea..3fb98baae 100644
--- a/hamilton/function_modifiers/macros.py
+++ b/hamilton/function_modifiers/macros.py
@@ -346,6 +346,7 @@ def __init__(
         :param _name: Name of the node to be created
         :param _namespace: Namespace of the node to be created -- currently only single-level namespaces are supported
         :param _target: Selects which target nodes it will be appended onto. Default None gets resolved on decorator level.
+            Specifically, pipe_input would use the first parameter and pipe_output / mutate would apply it to all sink nodes.
         :param kwargs: Kwargs (**kwargs) to pass to the function
         """

@@ -475,25 +476,26 @@ def named(self, name: str, namespace: NamespaceType = ...) -> "Applicable":
     # on_input / on_output are the same but here for naming convention
     # I know there is a way to dynamically resolve this to revert to a common function
     # just can't remember it now or find it online...
-    def on_input(self, target: base.TargetType) -> "Applicable":
-        """Add Target on a single function level.
-
-        This determines to which node(s) it will applies. Should match the same naming convention
-        as the NodeTransorfmLifecycle child class (for example NodeTransformer).
-
-        :param target: Which node(s) to apply on top of
-        :return: The Applicable with specified target
-        """
-        return Applicable(
-            fn=self.fn,
-            _resolvers=self.resolvers,
-            _name=self.name,
-            _namespace=self.namespace,
-            _target=target if target is not None else self.target,
-            args=self.args,
-            kwargs=self.kwargs,
-            target_fn=self.target_fn,
-        )
+    # TODO: adding the option to select target parameter for each transform
+    # def on_input(self, target: base.TargetType) -> "Applicable":
+    #     """Add Target on a single function level.
+
+    #     This determines which node(s) it applies to. Should match the same naming convention
+    #     as the NodeTransformLifecycle child class (for example NodeTransformer).
+
+    #     :param target: Which node(s) to apply on top of
+    #     :return: The Applicable with specified target
+    #     """
+    #     return Applicable(
+    #         fn=self.fn,
+    #         _resolvers=self.resolvers,
+    #         _name=self.name,
+    #         _namespace=self.namespace,
+    #         _target=target if target is not None else self.target,
+    #         args=self.args,
+    #         kwargs=self.kwargs,
+    #         target_fn=self.target_fn,
+    #     )

     def on_output(self, target: base.TargetType) -> "Applicable":
         """Add Target on a single function level.
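For orientation, a minimal sketch of the user-facing API shape before and after this hunk (illustrative only, not part of the patch; ``_add_one`` is a hypothetical helper and the import assumes the public ``hamilton.function_modifiers`` exports used elsewhere in this patch):

    # Sketch: the per-step Applicable.on_input target is commented out by this patch;
    # the decorator-level on_input (a single parameter name) remains supported.
    from hamilton.function_modifiers import pipe_input, step


    def _add_one(x: int) -> int:  # hypothetical helper
        return x + 1


    # Previously possible per step:  @pipe_input(step(_add_one).on_input("p2"))
    # Now: one target parameter for the whole pipe, set on the decorator itself.
    @pipe_input(step(_add_one), on_input="p2")
    def final_result(p1: int, p2: int) -> int:
        return p1 + p2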
@@ -649,18 +651,19 @@ def step( return Applicable(fn=fn, _resolvers=[], args=args, kwargs=kwargs) -class MissingTargetError(Exception): - """When setting target make sure it is clear which transform targets which node. +# TODO: In case of multiple parameter targets we want to safeguard that it is a clear target distribution +# class MissingTargetError(Exception): +# """When setting target make sure it is clear which transform targets which node. - This is a safeguard, because the default behavior may not apply if targets are partially set - and we do not want to make assumptions what the user meant. - """ +# This is a safeguard, because the default behavior may not apply if targets are partially set +# and we do not want to make assumptions what the user meant. +# """ - pass +# pass class pipe_input(base.NodeInjector): - """Running a series of transformation on the input of the function. + """Running a series of transformations on the input of the function. To demonstrate the rules for chaining nodes, we'll be using the following example. This is using primitives to demonstrate, but as hamilton is just functions of any python objects, this works perfectly with @@ -737,9 +740,9 @@ def final_result(upstream_int: int) -> int: One has three ways to tune the shape/implementation of the subsequent nodes: - 1. ``when``/``when_not``/``when_in``/``when_not_in`` -- these are used to filter the application of the function. This is valuable to reflect - if/else conditions in the structure of the DAG, pulling it out of functions, rather than buried within the logic itself. It is functionally - equivalent to ``@config.when``. + 1. ``when``/``when_not``/``when_in``/``when_not_in`` -- these are used to filter the application of the function. + This is valuable to reflect if/else conditions in the structure of the DAG, pulling it out of functions, rather + than buried within the logic itself. It is functionally equivalent to ``@config.when``. For instance, if you want to include a function in the chain only when a config parameter is set to a certain value, you can do: @@ -754,8 +757,8 @@ def final_result(upstream_int: int) -> int: This will only apply the first function when the config parameter ``foo`` is set to ``bar``, and the second when it is set to ``baz``. - 2. ``named`` -- this is used to name the node. This is useful if you want to refer to intermediate results. If this is left out, - hamilton will automatically name the functions in a globally unique manner. The names of + 2. ``named`` -- this is used to name the node. This is useful if you want to refer to intermediate results. + If this is left out, hamilton will automatically name the functions in a globally unique manner. The names of these functions will not necessarily be stable/guaranteed by the API, so if you want to refer to them, you should use ``named``. The default namespace will always be the name of the decorated function (which will be the last node in the chain). @@ -811,23 +814,11 @@ def final_result(upstream_int: int) -> int: In all likelihood, you should not be using this, and this is only here in case you want to expose a node for consumption/output later. Setting the namespace in individual nodes as well as in ``pipe_input`` is not yet supported. - 3. For extra control in case of multiple function arguments (parameters), we can also specify the target parameter that we wish to transform. - In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter only. 
If ``on_input`` is set for a specific transform - make sure the other ones are also set either through a global setting or individually, otherwise it is unclear which transforms target which parameters. - - The following applies *_add_one* to ``p1``, ``p3`` and *_add_two* to ``p2`` - - .. code-block:: python - - @pipe_input( - step(_add_one).on_input(["p1","p3"]) - step(_add_two, y=source("upstream_node")).on_input("p2") - ) - def final_result(p1: int, p2: int, p3: int) -> int: - return upstream_int + 3. ``on_input`` -- this selects which input we will run the pipeline on. + In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter. Let us know if you wish to expand to other use-cases. + You can track the progress on this topic via: https://github.com/DAGWorks-Inc/hamilton/issues/1177 - We can also do this on the global level to set for all transforms a target parameter. - The following would apply function *_add_one* and *_add_two* to ``p2`` + The following would apply function *_add_one* and *_add_two* to ``p2``: .. code-block:: python @@ -839,20 +830,39 @@ def final_result(p1: int, p2: int, p3: int) -> int: def final_result(p1: int, p2: int, p3: int) -> int: return upstream_int - Lastly, a mixture of global and local is possible, where the global selects the target parameters for - all transforms and we can select individual transforms to also target more parameters. - The following would apply function *_add_one* to all ``p1``, ``p2``, ``p3`` and *_add_two* also on ``p2`` + .. | + THIS IS COMMENTED OUT, I.E. SPHINX WILL NOT AUTODOC IT, HERE IN CASE WE ENABLE MULTIPLE PARAMETER TARGETS + For extra control in case of multiple function arguments (parameters), we can also specify the target parameter that we wish to transform. + In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter only. If ``on_input`` is set for a specific transform + make sure the other ones are also set either through a global setting or individually, otherwise it is unclear which transforms target which parameters. - .. code-block:: python + The following applies *_add_one* to ``p1``, ``p3`` and *_add_two* to ``p2`` - @pipe_input( - step(_add_one).on_input(["p1","p3"]) - step(_add_two, y=source("upstream_node")), - on_input = "p2" - ) - def final_result(p1: int, p2: int, p3: int) -> int: - return upstream_int + .. code-block:: python + + @pipe_input( + step(_add_one).on_input(["p1","p3"]) + step(_add_two, y=source("upstream_node")).on_input("p2") + ) + def final_result(p1: int, p2: int, p3: int) -> int: + return upstream_int + + We can also do this on the global level to set for all transforms a target parameter. + + Lastly, a mixture of global and local is possible, where the global selects the target parameters for + all transforms and we can select individual transforms to also target more parameters. + The following would apply function *_add_one* to all ``p1``, ``p2``, ``p3`` and *_add_two* also on ``p2`` + + .. code-block:: python + @pipe_input( + step(_add_one).on_input(["p1","p3"]) + step(_add_two, y=source("upstream_node")), + on_input = "p2" + ) + def final_result(p1: int, p2: int, p3: int) -> int: + return upstream_int + | replace:: \ """ def __init__( @@ -871,19 +881,27 @@ def __init__( :param collapse: Whether to collapse this into a single node. This is not currently supported. :param _chain: Whether to chain the first parameter. This is the only mode that is supported. Furthermore, this is not externally exposed. 
            ``@flow`` will make use of this.
        """
+        if on_input is not None:
+            if not isinstance(on_input, str):
+                raise NotImplementedError(
+                    "on_input currently only supports a single target parameter specified by a string. "
+                    "Please reach out if you want a more flexible option in the future."
+                )
         base.NodeTransformer._early_validate_target(target=on_input, allow_multiple=True)
         self.transforms = transforms
         self.collapse = collapse
         self.chain = _chain
         self.namespace = namespace
+        self.target = [on_input]
-        if isinstance(on_input, str):  # have to do extra since strings are collections in python
-            self.target = [on_input]
-        elif isinstance(on_input, Collection):
-            self.target = on_input
-        else:
-            self.target = [on_input]
+        # TODO: for multiple target parameter case
+        # if isinstance(on_input, str): # have to do extra since strings are collections in python
+        #     self.target = [on_input]
+        # elif isinstance(on_input, Collection):
+        #     self.target = on_input
+        # else:
+        #     self.target = [on_input]

         if self.collapse:
             raise NotImplementedError(
@@ -909,35 +927,42 @@ def _distribute_transforms_to_parameters(
         selected_transforms = defaultdict(list)
         for param in params:
-            for transform in self.transforms:
-                target = transform.target
-                # In case there is no target set on applicable we assign global target
-                if target is None:
-                    target = self.target
-                elif isinstance(target, str):  # user selects single target via string
-                    target = [target]
-                    target.extend(self.target)
-                elif isinstance(target, Collection):  # user inputs a list of targets
-                    target.extend(self.target)
-
-                if param in target:
-                    selected_transforms[param].append(transform)
+            if param in self.target:
+                selected_transforms[param].extend(self.transforms)
+            # TODO: in case of multiple parameters we can set individual targets and resolve them here
+            # for transform in self.transforms:
+            #     target = transform.target
+            #     # In case there is no target set on applicable we assign global target
+            #     if target is None:
+            #         target = self.target
+            #     elif isinstance(target, str):  # user selects single target via string
+            #         target = [target]
+            #         target.extend(self.target)
+            #     elif isinstance(target, Collection):  # user inputs a list of targets
+            #         target.extend(self.target)
+
+            #     if param in target:
+            #         selected_transforms[param].append(transform)
         return selected_transforms

-    def _check_parameters_transforms_mapping_validity(
+    def _create_valid_parameters_transforms_mapping(
         self, mapping: Dict[str, List[Applicable]], fn: Callable, params: Dict[str, Type[Type]]
     ) -> Dict[str, List[Applicable]]:
         """Checks for a valid distribution of transforms to parameters."""
-        if not mapping:
-            # This reverts back to legacy chaining through first parameter and checks first parameter
-            sig = inspect.signature(fn)
-            first_parameter = list(sig.parameters.values())[0].name
+        sig = inspect.signature(fn)
+        param_names = []
+        for param in list(sig.parameters.values()):
+            param_names.append(param.name)
         # use the name of the parameter to determine the first node
         # Then wire them all through in order
         # if it resolves, great
         # if not, skip that, pointing to the previous
         # Create a node along the way
+
+        if not mapping:
+            # This reverts back to legacy chaining through first parameter and checks first parameter
+            first_parameter = param_names[0]
             if first_parameter not in params:
                 raise base.InvalidDecoratorException(
                     f"Function: {fn.__name__} has a first parameter that is not a dependency. 
" @@ -945,18 +970,29 @@ def _check_parameters_transforms_mapping_validity( f"Thus it might not be compatible with some other decorators" ) mapping[first_parameter] = self.transforms - else: - # in case we set target this checks that each transform has at least one target parameter - transform_set = [] - for param in mapping: - transform_set.extend(mapping[param]) - transform_set = set(transform_set) - if len(transform_set) != len(self.transforms): - raise MissingTargetError( - "The on_input settings are unclear. Please make sure all transforms " - "either have specified individually or globally a target or there is " - "no on_input usage." - ) + # TODO: validate that all transforms have a target in case multiple parameters targeted + # else: + # # in case we set target this checks that each transform has at least one target parameter + # transform_set = [] + # for param in mapping: + # transform_set.extend(mapping[param]) + # transform_set = set(transform_set) + # if len(transform_set) != len(self.transforms): + # raise MissingTargetError( + # "The on_input settings are unclear. Please make sure all transforms " + # "either have specified individually or globally a target or there is " + # "no on_input usage." + # ) + + # similar to above we check that the target parameter is among the actual function parameters + if next(iter(mapping)) not in param_names: + raise base.InvalidDecoratorException( + f"Function: {fn.__name__} with parameters {param_names} does not a have " + f"dependency {next(iter(mapping))}. @pipe_input requires the parameter " + f"names to match the function parameters. Thus it might not be compatible " + f"with some other decorators." + ) + return mapping def _resolve_namespace( @@ -979,7 +1015,7 @@ def inject_nodes( then reassigns the inputs to pass it in.""" parameters_transforms_mapping = self._distribute_transforms_to_parameters(params=params) - parameters_transforms_mapping = self._check_parameters_transforms_mapping_validity( + parameters_transforms_mapping = self._create_valid_parameters_transforms_mapping( mapping=parameters_transforms_mapping, fn=fn, params=params ) diff --git a/tests/function_modifiers/test_macros.py b/tests/function_modifiers/test_macros.py index 59663897d..27f96b504 100644 --- a/tests/function_modifiers/test_macros.py +++ b/tests/function_modifiers/test_macros.py @@ -330,129 +330,20 @@ def function_multiple_same_type_params(p1: int, p2: int, p3: int) -> int: return p1 + p2 + p3 -def function_multiple_diverse_type_params(p1: int, p2: str, p3: int) -> int: - return p1 + len(p2) + p3 - - -def test_pipe_input_no_namespace_with_target(): - n = node.Node.from_fn(function_multiple_diverse_type_params) - - decorator = pipe_input( - step(_test_apply_function, source("bar_upstream"), baz=value(10)) - .on_input(["p1", "p3"]) - .named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)) - .on_input("p3") - .named("node_3"), - on_input="p2", - namespace=None, - ) - nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) - final_node = nodes[0].name - p1_node = nodes[1].name - p2_node1 = nodes[2].name - p2_node2 = nodes[3].name - p2_node3 = nodes[4].name - p3_node1 = nodes[5].name - p3_node2 = nodes[6].name - - assert final_node == "function_multiple_diverse_type_params" - assert p1_node == "p1.node_1" - assert p2_node1 == "p2.node_1" - assert p2_node2 == "p2.node_2" - assert p2_node3 == "p2.node_3" - assert p3_node1 == 
"p3.node_1" - assert p3_node2 == "p3.node_3" - - -def test_pipe_input_elipsis_namespace_with_target(): - n = node.Node.from_fn(function_multiple_diverse_type_params) - - decorator = pipe_input( - step(_test_apply_function, source("bar_upstream"), baz=value(10)) - .on_input(["p1", "p3"]) - .named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)) - .on_input("p3") - .named("node_3"), - namespace=..., - on_input="p2", - ) - nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) - final_node = nodes[0].name - p1_node = nodes[1].name - p2_node1 = nodes[2].name - p2_node2 = nodes[3].name - p2_node3 = nodes[4].name - p3_node1 = nodes[5].name - p3_node2 = nodes[6].name - - assert final_node == "function_multiple_diverse_type_params" - assert p1_node == "p1.node_1" - assert p2_node1 == "p2.node_1" - assert p2_node2 == "p2.node_2" - assert p2_node3 == "p2.node_3" - assert p3_node1 == "p3.node_1" - assert p3_node2 == "p3.node_3" - - -def test_pipe_input_custom_namespace_with_target(): - n = node.Node.from_fn(function_multiple_diverse_type_params) - - decorator = pipe_input( - step(_test_apply_function, source("bar_upstream"), baz=value(10)) - .on_input(["p1", "p3"]) - .named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)) - .on_input("p3") - .named("node_3"), - namespace="abc", - on_input="p2", - ) - nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) - final_node = nodes[0].name - p1_node = nodes[1].name - p2_node1 = nodes[2].name - p2_node2 = nodes[3].name - p2_node3 = nodes[4].name - p3_node1 = nodes[5].name - p3_node2 = nodes[6].name - - assert final_node == "function_multiple_diverse_type_params" - assert p1_node == "abc_p1.node_1" - assert p2_node1 == "abc_p2.node_1" - assert p2_node2 == "abc_p2.node_2" - assert p2_node3 == "abc_p2.node_3" - assert p3_node1 == "abc_p3.node_1" - assert p3_node2 == "abc_p3.node_3" - - -def test_pipe_input_mapping_args_targets_local(): - n = node.Node.from_fn(function_multiple_diverse_type_params) - - decorator = pipe_input( - step(_test_apply_function, source("bar_upstream"), baz=value(10)) - .on_input(["p1", "p3"]) - .named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)) - .on_input("p2") - .named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)) - .on_input("p3") - .named("node_3"), - namespace="abc", - ) - nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) - nodes_by_name = {item.name: item for item in nodes} - chain_node_1 = nodes_by_name["abc_p1.node_1"] - chain_node_2 = nodes_by_name["abc_p2.node_2"] - chain_node_3_first = nodes_by_name["abc_p3.node_1"] - assert chain_node_1(p1=1, bar_upstream=3) == 14 - assert chain_node_2(p2=1, bar_upstream=3) == 104 - assert chain_node_3_first(p3=7, bar_upstream=3) == 20 +# TODO: in case of multiple paramters need some type checking +# def function_multiple_diverse_type_params(p1: int, p2: str, p3: int) -> int: +# return p1 + len(p2) + p3 + + +def test_pipe_input_on_input_error_unless_string_or_none(): + with pytest.raises(NotImplementedError): + decorator = pipe_input( # noqa + step(_test_apply_function, source("bar_upstream"), baz=value(10)).named("node_1"), + step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), + 
step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), + on_input=["p2", "p3"], + namespace="abc", + ) def test_pipe_input_mapping_args_targets_global(): @@ -465,88 +356,210 @@ def test_pipe_input_mapping_args_targets_global(): on_input="p2", namespace="abc", ) - nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) + nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) nodes_by_name = {item.name: item for item in nodes} chain_node = nodes_by_name["abc.node_1"] assert chain_node(p2=1, bar_upstream=3) == 14 -def test_pipe_input_mapping_args_targets_local_adds_to_global(): - n = node.Node.from_fn(function_multiple_same_type_params) - - decorator = pipe_input( - step(_test_apply_function, source("bar_upstream"), baz=value(10)) - .on_input(["p1", "p2"]) - .named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)) - .on_input("p2") - .named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), - on_input="p3", - namespace="abc", - ) - nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) - nodes_by_name = {item.name: item for item in nodes} - p1_node = nodes_by_name["abc_p1.node_1"] - p2_node1 = nodes_by_name["abc_p2.node_1"] - p2_node2 = nodes_by_name["abc_p2.node_2"] - p3_node1 = nodes_by_name["abc_p3.node_1"] - p3_node2 = nodes_by_name["abc_p3.node_2"] - p3_node3 = nodes_by_name["abc_p3.node_3"] - - assert p1_node(p1=1, bar_upstream=3) == 14 - assert p2_node1(p2=7, bar_upstream=3) == 20 - assert p2_node2(**{"abc_p2.node_1": 2, "bar_upstream": 3}) == 105 - assert p3_node1(p3=9, bar_upstream=3) == 22 - assert p3_node2(**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 116 - assert p3_node3(**{"abc_p3.node_2": 17, "bar_upstream": 3}) == 1020 - - -def test_pipe_input_fails_with_missing_targets(): - n = node.Node.from_fn(function_multiple_same_type_params) - - decorator = pipe_input( - step(_test_apply_function, source("bar_upstream"), baz=value(10)) - .on_input(["p1", "p2"]) - .named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)) - .on_input("p2") - .named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), - namespace="abc", - ) - with pytest.raises(hamilton.function_modifiers.macros.MissingTargetError): - nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) # noqa - - -def test_pipe_input_decorator_with_target_no_collapse_multi_node(): - n = node.Node.from_fn(function_multiple_same_type_params) - - decorator = pipe_input( - step(_test_apply_function, source("bar_upstream"), baz=value(10)) - .on_input(["p1", "p3"]) - .named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)) - .on_input("p2") - .named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)) - .on_input("p3") - .named("node_3"), - namespace="abc", - ) - nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) - nodes_by_name = {item.name: item for item in nodes} - final_node = nodes_by_name["function_multiple_same_type_params"] - chain_node_1 = nodes_by_name["abc_p1.node_1"] - chain_node_2 = nodes_by_name["abc_p2.node_2"] - chain_node_3_first = nodes_by_name["abc_p3.node_1"] - chain_node_3_second = nodes_by_name["abc_p3.node_3"] - assert len(nodes_by_name) == 5 - assert chain_node_1(p1=1, bar_upstream=3) == 14 - assert chain_node_2(p2=1, bar_upstream=3) == 104 - assert 
chain_node_3_first(p3=7, bar_upstream=3) == 20 - assert chain_node_3_second(**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 1016 - assert final_node(**{"abc_p1.node_1": 3, "abc_p2.node_2": 4, "abc_p3.node_3": 5}) == 12 +# TODO: multiple parameter tests +# def test_pipe_input_no_namespace_with_target(): +# n = node.Node.from_fn(function_multiple_diverse_type_params) + +# decorator = pipe_input( +# step(_test_apply_function, source("bar_upstream"), baz=value(10)) +# .on_input(["p1", "p3"]) +# .named("node_1"), +# step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), +# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) +# .on_input("p3") +# .named("node_3"), +# on_input="p2", +# namespace=None, +# ) +# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) +# final_node = nodes[0].name +# p1_node = nodes[1].name +# p2_node1 = nodes[2].name +# p2_node2 = nodes[3].name +# p2_node3 = nodes[4].name +# p3_node1 = nodes[5].name +# p3_node2 = nodes[6].name + +# assert final_node == "function_multiple_diverse_type_params" +# assert p1_node == "p1.node_1" +# assert p2_node1 == "p2.node_1" +# assert p2_node2 == "p2.node_2" +# assert p2_node3 == "p2.node_3" +# assert p3_node1 == "p3.node_1" +# assert p3_node2 == "p3.node_3" + + +# def test_pipe_input_elipsis_namespace_with_target(): +# n = node.Node.from_fn(function_multiple_diverse_type_params) + +# decorator = pipe_input( +# step(_test_apply_function, source("bar_upstream"), baz=value(10)) +# .on_input(["p1", "p3"]) +# .named("node_1"), +# step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), +# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) +# .on_input("p3") +# .named("node_3"), +# namespace=..., +# on_input="p2", +# ) +# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) +# final_node = nodes[0].name +# p1_node = nodes[1].name +# p2_node1 = nodes[2].name +# p2_node2 = nodes[3].name +# p2_node3 = nodes[4].name +# p3_node1 = nodes[5].name +# p3_node2 = nodes[6].name + +# assert final_node == "function_multiple_diverse_type_params" +# assert p1_node == "p1.node_1" +# assert p2_node1 == "p2.node_1" +# assert p2_node2 == "p2.node_2" +# assert p2_node3 == "p2.node_3" +# assert p3_node1 == "p3.node_1" +# assert p3_node2 == "p3.node_3" + + +# def test_pipe_input_custom_namespace_with_target(): +# n = node.Node.from_fn(function_multiple_diverse_type_params) + +# decorator = pipe_input( +# step(_test_apply_function, source("bar_upstream"), baz=value(10)) +# .on_input(["p1", "p3"]) +# .named("node_1"), +# step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), +# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) +# .on_input("p3") +# .named("node_3"), +# namespace="abc", +# on_input="p2", +# ) +# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) +# final_node = nodes[0].name +# p1_node = nodes[1].name +# p2_node1 = nodes[2].name +# p2_node2 = nodes[3].name +# p2_node3 = nodes[4].name +# p3_node1 = nodes[5].name +# p3_node2 = nodes[6].name + +# assert final_node == "function_multiple_diverse_type_params" +# assert p1_node == "abc_p1.node_1" +# assert p2_node1 == "abc_p2.node_1" +# assert p2_node2 == "abc_p2.node_2" +# assert p2_node3 == "abc_p2.node_3" +# assert p3_node1 == "abc_p3.node_1" +# assert p3_node2 == "abc_p3.node_3" + + +# def test_pipe_input_mapping_args_targets_local(): +# n = 
node.Node.from_fn(function_multiple_diverse_type_params) + +# decorator = pipe_input( +# step(_test_apply_function, source("bar_upstream"), baz=value(10)) +# .on_input(["p1", "p3"]) +# .named("node_1"), +# step(_test_apply_function, source("bar_upstream"), baz=value(100)) +# .on_input("p2") +# .named("node_2"), +# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) +# .on_input("p3") +# .named("node_3"), +# namespace="abc", +# ) +# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) +# nodes_by_name = {item.name: item for item in nodes} +# chain_node_1 = nodes_by_name["abc_p1.node_1"] +# chain_node_2 = nodes_by_name["abc_p2.node_2"] +# chain_node_3_first = nodes_by_name["abc_p3.node_1"] +# assert chain_node_1(p1=1, bar_upstream=3) == 14 +# assert chain_node_2(p2=1, bar_upstream=3) == 104 +# assert chain_node_3_first(p3=7, bar_upstream=3) == 20 +# +# +# def test_pipe_input_mapping_args_targets_local_adds_to_global(): +# n = node.Node.from_fn(function_multiple_same_type_params) + +# decorator = pipe_input( +# step(_test_apply_function, source("bar_upstream"), baz=value(10)) +# .on_input(["p1", "p2"]) +# .named("node_1"), +# step(_test_apply_function, source("bar_upstream"), baz=value(100)) +# .on_input("p2") +# .named("node_2"), +# step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), +# on_input="p3", +# namespace="abc", +# ) +# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) +# nodes_by_name = {item.name: item for item in nodes} +# p1_node = nodes_by_name["abc_p1.node_1"] +# p2_node1 = nodes_by_name["abc_p2.node_1"] +# p2_node2 = nodes_by_name["abc_p2.node_2"] +# p3_node1 = nodes_by_name["abc_p3.node_1"] +# p3_node2 = nodes_by_name["abc_p3.node_2"] +# p3_node3 = nodes_by_name["abc_p3.node_3"] + +# assert p1_node(p1=1, bar_upstream=3) == 14 +# assert p2_node1(p2=7, bar_upstream=3) == 20 +# assert p2_node2(**{"abc_p2.node_1": 2, "bar_upstream": 3}) == 105 +# assert p3_node1(p3=9, bar_upstream=3) == 22 +# assert p3_node2(**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 116 +# assert p3_node3(**{"abc_p3.node_2": 17, "bar_upstream": 3}) == 1020 + + +# def test_pipe_input_fails_with_missing_targets(): +# n = node.Node.from_fn(function_multiple_same_type_params) + +# decorator = pipe_input( +# step(_test_apply_function, source("bar_upstream"), baz=value(10)) +# .on_input(["p1", "p2"]) +# .named("node_1"), +# step(_test_apply_function, source("bar_upstream"), baz=value(100)) +# .on_input("p2") +# .named("node_2"), +# step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), +# namespace="abc", +# ) +# with pytest.raises(hamilton.function_modifiers.macros.MissingTargetError): +# nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) # noqa + + +# def test_pipe_input_decorator_with_target_no_collapse_multi_node(): +# n = node.Node.from_fn(function_multiple_same_type_params) + +# decorator = pipe_input( +# step(_test_apply_function, source("bar_upstream"), baz=value(10)) +# .on_input(["p1", "p3"]) +# .named("node_1"), +# step(_test_apply_function, source("bar_upstream"), baz=value(100)) +# .on_input("p2") +# .named("node_2"), +# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) +# .on_input("p3") +# .named("node_3"), +# namespace="abc", +# ) +# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) +# nodes_by_name = {item.name: item for item in nodes} +# final_node = 
nodes_by_name["function_multiple_same_type_params"] +# chain_node_1 = nodes_by_name["abc_p1.node_1"] +# chain_node_2 = nodes_by_name["abc_p2.node_2"] +# chain_node_3_first = nodes_by_name["abc_p3.node_1"] +# chain_node_3_second = nodes_by_name["abc_p3.node_3"] +# assert len(nodes_by_name) == 5 +# assert chain_node_1(p1=1, bar_upstream=3) == 14 +# assert chain_node_2(p2=1, bar_upstream=3) == 104 +# assert chain_node_3_first(p3=7, bar_upstream=3) == 20 +# assert chain_node_3_second(**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 1016 +# assert final_node(**{"abc_p1.node_1": 3, "abc_p2.node_2": 4, "abc_p3.node_3": 5}) == 12 def test_pipe_decorator_positional_variable_args(): @@ -655,33 +668,6 @@ def test_pipe_end_to_end_1(): assert result["chain_2_using_pipe"] == result["chain_2_not_using_pipe"] -def test_pipe_end_to_end_target_local(): - dr = ( - driver.Builder() - .with_modules(tests.resources.pipe_input) - .with_adapter(base.DefaultAdapter()) - .with_config({"calc_c": True}) - .build() - ) - - inputs = { - "input_1": 10, - "input_2": 20, - "input_3": 30, - } - result = dr.execute( - [ - "chain_1_using_pipe_input_target_local", - "chain_1_not_using_pipe_input_target_local", - ], - inputs=inputs, - ) - assert ( - result["chain_1_not_using_pipe_input_target_local"] - == result["chain_1_using_pipe_input_target_local"] - ) - - def test_pipe_end_to_end_target_global(): dr = ( driver.Builder() @@ -709,31 +695,58 @@ def test_pipe_end_to_end_target_global(): ) -def test_pipe_end_to_end_target_mixed(): - dr = ( - driver.Builder() - .with_modules(tests.resources.pipe_input) - .with_adapter(base.DefaultAdapter()) - .with_config({"calc_c": True}) - .build() - ) - - inputs = { - "input_1": 10, - "input_2": 20, - "input_3": 30, - } - result = dr.execute( - [ - "chain_1_using_pipe_input_target_mixed", - "chain_1_not_using_pipe_input_target_mixed", - ], - inputs=inputs, - ) - assert ( - result["chain_1_not_using_pipe_input_target_mixed"] - == result["chain_1_using_pipe_input_target_mixed"] - ) +# TODO: For multiple parameters end-to-end +# def test_pipe_end_to_end_target_local(): +# dr = ( +# driver.Builder() +# .with_modules(tests.resources.pipe_input) +# .with_adapter(base.DefaultAdapter()) +# .with_config({"calc_c": True}) +# .build() +# ) + +# inputs = { +# "input_1": 10, +# "input_2": 20, +# "input_3": 30, +# } +# result = dr.execute( +# [ +# "chain_1_using_pipe_input_target_local", +# "chain_1_not_using_pipe_input_target_local", +# ], +# inputs=inputs, +# ) +# assert ( +# result["chain_1_not_using_pipe_input_target_local"] +# == result["chain_1_using_pipe_input_target_local"] +# ) + +# def test_pipe_end_to_end_target_mixed(): +# dr = ( +# driver.Builder() +# .with_modules(tests.resources.pipe_input) +# .with_adapter(base.DefaultAdapter()) +# .with_config({"calc_c": True}) +# .build() +# ) + +# inputs = { +# "input_1": 10, +# "input_2": 20, +# "input_3": 30, +# } +# result = dr.execute( +# [ +# "chain_1_using_pipe_input_target_mixed", +# "chain_1_not_using_pipe_input_target_mixed", +# ], +# inputs=inputs, +# ) +# assert ( +# result["chain_1_not_using_pipe_input_target_mixed"] +# == result["chain_1_using_pipe_input_target_mixed"] +# ) def result_from_downstream_function() -> int: diff --git a/tests/resources/pipe_input.py b/tests/resources/pipe_input.py index 16ac2c0ae..0902ce49d 100644 --- a/tests/resources/pipe_input.py +++ b/tests/resources/pipe_input.py @@ -93,52 +93,53 @@ def chain_1_not_using_pipe_input_target_global( return v + e * 10 -@pipe_input( - step(_add_one).on_input(["v", 
"w"]).named("a"), - step(_add_two).on_input("v").named("b"), - step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), - step(_add_n, n=source("input_1")).on_input("v").named("d"), - step(_multiply_n, n=source("input_2")).on_input("w").named("e"), - namespace="local", -) -def chain_1_using_pipe_input_target_local(v: int, w: int) -> int: - return v + w * 10 - - -def chain_1_not_using_pipe_input_target_local( - v: int, w: int, input_1: int, input_2: int, calc_c: bool = False -) -> int: - av = _add_one(v) - aw = _add_one(w) - bv = _add_two(av) - cw = _add_n(aw, n=3) if calc_c else aw - dv = _add_n(bv, n=input_1) - ew = _multiply_n(cw, n=input_2) - return dv + ew * 10 - - -@pipe_input( - step(_add_one).on_input("w").named("a"), - step(_add_two).named("b"), - step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), - step(_add_n, n=source("input_1")).named("d"), - step(_multiply_n, n=source("input_2")).on_input("w").named("e"), - namespace="mixed", - on_input="v", -) -def chain_1_using_pipe_input_target_mixed(v: int, w: int) -> int: - return v + w * 10 - - -def chain_1_not_using_pipe_input_target_mixed( - v: int, w: int, input_1: int, input_2: int, calc_c: bool = False -) -> int: - av = _add_one(v) - aw = _add_one(w) - bv = _add_two(av) - cv = _add_n(bv, n=3) if calc_c else bv - cw = _add_n(aw, n=3) if calc_c else aw - dv = _add_n(cv, n=input_1) - ev = _multiply_n(dv, n=input_2) - ew = _multiply_n(cw, n=input_2) - return ev + ew * 10 +# TODO: for tests in case of multiple target parameters +# @pipe_input( +# step(_add_one).on_input(["v", "w"]).named("a"), +# step(_add_two).on_input("v").named("b"), +# step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), +# step(_add_n, n=source("input_1")).on_input("v").named("d"), +# step(_multiply_n, n=source("input_2")).on_input("w").named("e"), +# namespace="local", +# ) +# def chain_1_using_pipe_input_target_local(v: int, w: int) -> int: +# return v + w * 10 + + +# def chain_1_not_using_pipe_input_target_local( +# v: int, w: int, input_1: int, input_2: int, calc_c: bool = False +# ) -> int: +# av = _add_one(v) +# aw = _add_one(w) +# bv = _add_two(av) +# cw = _add_n(aw, n=3) if calc_c else aw +# dv = _add_n(bv, n=input_1) +# ew = _multiply_n(cw, n=input_2) +# return dv + ew * 10 + + +# @pipe_input( +# step(_add_one).on_input("w").named("a"), +# step(_add_two).named("b"), +# step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), +# step(_add_n, n=source("input_1")).named("d"), +# step(_multiply_n, n=source("input_2")).on_input("w").named("e"), +# namespace="mixed", +# on_input="v", +# ) +# def chain_1_using_pipe_input_target_mixed(v: int, w: int) -> int: +# return v + w * 10 + + +# def chain_1_not_using_pipe_input_target_mixed( +# v: int, w: int, input_1: int, input_2: int, calc_c: bool = False +# ) -> int: +# av = _add_one(v) +# aw = _add_one(w) +# bv = _add_two(av) +# cv = _add_n(bv, n=3) if calc_c else bv +# cw = _add_n(aw, n=3) if calc_c else aw +# dv = _add_n(cv, n=input_1) +# ev = _multiply_n(dv, n=input_2) +# ew = _multiply_n(cw, n=input_2) +# return ev + ew * 10