Skip to content

Commit

Permalink
Restrict pipe_input API to single on_input target
Browse files Browse the repository at this point in the history
The implementation allows for global string input for a target.

- Functionality is there for multiple global and local parameter targets, but
  we shortcircuit for the moment with an error if the arg is not a
  string.
- At the moment all transforms in the pipe get applied to the same
  argument.
- Test and docs have commented out the multiple parameter cases for easy
  enabling should we want this.
- Added URL of the GitHub issue tracking interest to enable multiple
  params
  • Loading branch information
jernejfrank committed Oct 12, 2024
1 parent afcdb96 commit 5188e08
Show file tree
Hide file tree
Showing 3 changed files with 446 additions and 396 deletions.
226 changes: 131 additions & 95 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def __init__(
:param _name: Name of the node to be created
:param _namespace: Namespace of the node to be created -- currently only single-level namespaces are supported
:param _target: Selects which target nodes it will be appended onto. Default None gets resolved on decorator level.
Specifically, pipe_input would use the first parameter and pipe_output / mutate would apply it to all sink nodes.
:param kwargs: Kwargs (**kwargs) to pass to the function
"""

Expand Down Expand Up @@ -475,25 +476,26 @@ def named(self, name: str, namespace: NamespaceType = ...) -> "Applicable":
# on_input / on_output are the same but here for naming convention
# I know there is a way to dynamically resolve this to revert to a common function
# just can't remember it now or find it online...
def on_input(self, target: base.TargetType) -> "Applicable":
"""Add Target on a single function level.
This determines which node(s) it applies to. Should match the same naming convention
as the NodeTransformLifecycle child class (for example NodeTransformer).
:param target: Which node(s) to apply on top of
:return: The Applicable with specified target
"""
return Applicable(
fn=self.fn,
_resolvers=self.resolvers,
_name=self.name,
_namespace=self.namespace,
_target=target if target is not None else self.target,
args=self.args,
kwargs=self.kwargs,
target_fn=self.target_fn,
)
# TODO: adding the option to select target parameter for each transform
# def on_input(self, target: base.TargetType) -> "Applicable":
# """Add Target on a single function level.

# This determines which node(s) it applies to. Should match the same naming convention
# as the NodeTransformLifecycle child class (for example NodeTransformer).

# :param target: Which node(s) to apply on top of
# :return: The Applicable with specified target
# """
# return Applicable(
# fn=self.fn,
# _resolvers=self.resolvers,
# _name=self.name,
# _namespace=self.namespace,
# _target=target if target is not None else self.target,
# args=self.args,
# kwargs=self.kwargs,
# target_fn=self.target_fn,
# )

def on_output(self, target: base.TargetType) -> "Applicable":
"""Add Target on a single function level.
Expand Down Expand Up @@ -649,18 +651,19 @@ def step(
return Applicable(fn=fn, _resolvers=[], args=args, kwargs=kwargs)


class MissingTargetError(Exception):
"""When setting target make sure it is clear which transform targets which node.
# TODO: In case of multiple parameter targets we want to safeguard that it is a clear target distribution
# class MissingTargetError(Exception):
# """When setting target make sure it is clear which transform targets which node.

This is a safeguard, because the default behavior may not apply if targets are partially set
and we do not want to make assumptions what the user meant.
"""
# This is a safeguard, because the default behavior may not apply if targets are partially set
# and we do not want to make assumptions what the user meant.
# """

pass
# pass


class pipe_input(base.NodeInjector):
"""Running a series of transformation on the input of the function.
"""Running a series of transformations on the input of the function.
To demonstrate the rules for chaining nodes, we'll be using the following example. This is
using primitives to demonstrate, but as hamilton is just functions of any python objects, this works perfectly with
Expand Down Expand Up @@ -737,9 +740,9 @@ def final_result(upstream_int: int) -> int:
One has three ways to tune the shape/implementation of the subsequent nodes:
1. ``when``/``when_not``/``when_in``/``when_not_in`` -- these are used to filter the application of the function. This is valuable to reflect
if/else conditions in the structure of the DAG, pulling it out of functions, rather than buried within the logic itself. It is functionally
equivalent to ``@config.when``.
1. ``when``/``when_not``/``when_in``/``when_not_in`` -- these are used to filter the application of the function.
This is valuable to reflect if/else conditions in the structure of the DAG, pulling it out of functions, rather
than buried within the logic itself. It is functionally equivalent to ``@config.when``.
For instance, if you want to include a function in the chain only when a config parameter is set to a certain value, you can do:
Expand All @@ -754,8 +757,8 @@ def final_result(upstream_int: int) -> int:
This will only apply the first function when the config parameter ``foo`` is set to ``bar``, and the second when it is set to ``baz``.
2. ``named`` -- this is used to name the node. This is useful if you want to refer to intermediate results. If this is left out,
hamilton will automatically name the functions in a globally unique manner. The names of
2. ``named`` -- this is used to name the node. This is useful if you want to refer to intermediate results.
If this is left out, hamilton will automatically name the functions in a globally unique manner. The names of
these functions will not necessarily be stable/guaranteed by the API, so if you want to refer to them, you should use ``named``.
The default namespace will always be the name of the decorated function (which will be the last node in the chain).
Expand Down Expand Up @@ -811,23 +814,11 @@ def final_result(upstream_int: int) -> int:
In all likelihood, you should not be using this, and this is only here in case you want to expose a node for
consumption/output later. Setting the namespace in individual nodes as well as in ``pipe_input`` is not yet supported.
3. For extra control in case of multiple function arguments (parameters), we can also specify the target parameter that we wish to transform.
In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter only. If ``on_input`` is set for a specific transform
make sure the other ones are also set either through a global setting or individually, otherwise it is unclear which transforms target which parameters.
The following applies *_add_one* to ``p1``, ``p3`` and *_add_two* to ``p2``
.. code-block:: python
@pipe_input(
step(_add_one).on_input(["p1","p3"])
step(_add_two, y=source("upstream_node")).on_input("p2")
)
def final_result(p1: int, p2: int, p3: int) -> int:
return upstream_int
3. ``on_input`` -- this selects which input we will run the pipeline on.
In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter. Let us know if you wish to expand to other use-cases.
You can track the progress on this topic via: https://github.com/DAGWorks-Inc/hamilton/issues/1177
We can also do this on the global level to set for all transforms a target parameter.
The following would apply function *_add_one* and *_add_two* to ``p2``
The following would apply function *_add_one* and *_add_two* to ``p2``:
.. code-block:: python
Expand All @@ -839,20 +830,39 @@ def final_result(p1: int, p2: int, p3: int) -> int:
def final_result(p1: int, p2: int, p3: int) -> int:
return upstream_int
Lastly, a mixture of global and local is possible, where the global selects the target parameters for
all transforms and we can select individual transforms to also target more parameters.
The following would apply function *_add_one* to all ``p1``, ``p2``, ``p3`` and *_add_two* also on ``p2``
.. |
THIS IS COMMENTED OUT, I.E. SPHINX WILL NOT AUTODOC IT, HERE IN CASE WE ENABLE MULTIPLE PARAMETER TARGETS
For extra control in case of multiple function arguments (parameters), we can also specify the target parameter that we wish to transform.
In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter only. If ``on_input`` is set for a specific transform
make sure the other ones are also set either through a global setting or individually, otherwise it is unclear which transforms target which parameters.
.. code-block:: python
The following applies *_add_one* to ``p1``, ``p3`` and *_add_two* to ``p2``
@pipe_input(
step(_add_one).on_input(["p1","p3"])
step(_add_two, y=source("upstream_node")),
on_input = "p2"
)
def final_result(p1: int, p2: int, p3: int) -> int:
return upstream_int
.. code-block:: python
@pipe_input(
step(_add_one).on_input(["p1","p3"]),
step(_add_two, y=source("upstream_node")).on_input("p2")
)
def final_result(p1: int, p2: int, p3: int) -> int:
return upstream_int
We can also do this on the global level to set for all transforms a target parameter.
Lastly, a mixture of global and local is possible, where the global selects the target parameters for
all transforms and we can select individual transforms to also target more parameters.
The following would apply function *_add_one* to all ``p1``, ``p2``, ``p3`` and *_add_two* also on ``p2``
.. code-block:: python
@pipe_input(
step(_add_one).on_input(["p1","p3"]),
step(_add_two, y=source("upstream_node")),
on_input = "p2"
)
def final_result(p1: int, p2: int, p3: int) -> int:
return upstream_int
| replace:: \
"""

def __init__(
Expand All @@ -871,19 +881,27 @@ def __init__(
:param collapse: Whether to collapse this into a single node. This is not currently supported.
:param _chain: Whether to chain the first parameter. This is the only mode that is supported. Furthermore, this is not externally exposed. ``@flow`` will make use of this.
"""
if on_input is not None:
if not isinstance(on_input, str):
raise NotImplementedError(
"on_input currently only supports a single target parameter specified by a string. "
"Please reach out if you want a more flexible option in the feature."
)
base.NodeTransformer._early_validate_target(target=on_input, allow_multiple=True)

self.transforms = transforms
self.collapse = collapse
self.chain = _chain
self.namespace = namespace
self.target = [on_input]

if isinstance(on_input, str): # have to do extra since strings are collections in python
self.target = [on_input]
elif isinstance(on_input, Collection):
self.target = on_input
else:
self.target = [on_input]
# TODO: for multiple target parameter case
# if isinstance(on_input, str): # have to do extra since strings are collections in python
# self.target = [on_input]
# elif isinstance(on_input, Collection):
# self.target = on_input
# else:
# self.target = [on_input]

if self.collapse:
raise NotImplementedError(
Expand All @@ -909,54 +927,72 @@ def _distribute_transforms_to_parameters(

selected_transforms = defaultdict(list)
for param in params:
for transform in self.transforms:
target = transform.target
# In case there is no target set on applicable we assign global target
if target is None:
target = self.target
elif isinstance(target, str): # user selects single target via string
target = [target]
target.extend(self.target)
elif isinstance(target, Collection): # user inputs a list of targets
target.extend(self.target)

if param in target:
selected_transforms[param].append(transform)
if param in self.target:
selected_transforms[param].extend(self.transforms)
# TODO: in case of multiple parameters we can set individual targets and resolve them here
# for transform in self.transforms:
# target = transform.target
# # In case there is no target set on applicable we assign global target
# if target is None:
# target = self.target
# elif isinstance(target, str): # user selects single target via string
# target = [target]
# target.extend(self.target)
# elif isinstance(target, Collection): # user inputs a list of targets
# target.extend(self.target)

# if param in target:
# selected_transforms[param].append(transform)

return selected_transforms

def _check_parameters_transforms_mapping_validity(
def _create_valid_parameters_transforms_mapping(
self, mapping: Dict[str, List[Applicable]], fn: Callable, params: Dict[str, Type[Type]]
) -> Dict[str, List[Applicable]]:
"""Checks for a valid distribution of transforms to parameters."""
if not mapping:
# This reverts back to legacy chaining through first parameter and checks first parameter
sig = inspect.signature(fn)
first_parameter = list(sig.parameters.values())[0].name
sig = inspect.signature(fn)
param_names = []
for param in list(sig.parameters.values()):
param_names.append(param.name)
# use the name of the parameter to determine the first node
# Then wire them all through in order
# if it resolves, great
# if not, skip that, pointing to the previous
# Create a node along the way

if not mapping:
# This reverts back to legacy chaining through first parameter and checks first parameter
first_parameter = param_names[0]
if first_parameter not in params:
raise base.InvalidDecoratorException(
f"Function: {fn.__name__} has a first parameter that is not a dependency. "
f"@pipe requires the parameter names to match the function parameters. "
f"Thus it might not be compatible with some other decorators"
)
mapping[first_parameter] = self.transforms
else:
# in case we set target this checks that each transform has at least one target parameter
transform_set = []
for param in mapping:
transform_set.extend(mapping[param])
transform_set = set(transform_set)
if len(transform_set) != len(self.transforms):
raise MissingTargetError(
"The on_input settings are unclear. Please make sure all transforms "
"either have specified individually or globally a target or there is "
"no on_input usage."
)
# TODO: validate that all transforms have a target in case multiple parameters targeted
# else:
# # in case we set target this checks that each transform has at least one target parameter
# transform_set = []
# for param in mapping:
# transform_set.extend(mapping[param])
# transform_set = set(transform_set)
# if len(transform_set) != len(self.transforms):
# raise MissingTargetError(
# "The on_input settings are unclear. Please make sure all transforms "
# "either have specified individually or globally a target or there is "
# "no on_input usage."
# )

# similar to above we check that the target parameter is among the actual function parameters
if next(iter(mapping)) not in param_names:
raise base.InvalidDecoratorException(
f"Function: {fn.__name__} with parameters {param_names} does not a have "
f"dependency {next(iter(mapping))}. @pipe_input requires the parameter "
f"names to match the function parameters. Thus it might not be compatible "
f"with some other decorators."
)

return mapping

def _resolve_namespace(
Expand All @@ -979,7 +1015,7 @@ def inject_nodes(
then reassigns the inputs to pass it in."""

parameters_transforms_mapping = self._distribute_transforms_to_parameters(params=params)
parameters_transforms_mapping = self._check_parameters_transforms_mapping_validity(
parameters_transforms_mapping = self._create_valid_parameters_transforms_mapping(
mapping=parameters_transforms_mapping, fn=fn, params=params
)

Expand Down
Loading

0 comments on commit 5188e08

Please sign in to comment.