From 5ce3b76f54e928c6c0f386cf5262478e2d42970b Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sun, 18 Aug 2024 15:07:17 -0700 Subject: [PATCH 1/2] Adds setstate getstate to driver and fixes 1093 This fixes #1093. Modules were not picklable. So this fixes that by serializing their fully qualified names, and then when the driver object is deserialized, they are reinstantiated as module objects. --- hamilton/driver.py | 19 ++++++++++++ tests/resources/test_driver_serde_mapper.py | 33 +++++++++++++++++++++ tests/resources/test_driver_serde_worker.py | 2 ++ tests/test_hamilton_driver.py | 28 +++++++++++++++++ 4 files changed, 82 insertions(+) create mode 100644 tests/resources/test_driver_serde_mapper.py create mode 100644 tests/resources/test_driver_serde_worker.py diff --git a/hamilton/driver.py b/hamilton/driver.py index 67799bb37..436350641 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -1,5 +1,7 @@ import abc import functools +import importlib +import importlib.util import json import logging import operator @@ -260,6 +262,23 @@ class Driver: dr = driver.Driver(config, module, adapter=adapter) """ + def __getstate__(self): + # Copy the object's state from self.__dict__ + state = self.__dict__.copy() + # Remove the unpicklable entries -- right now it's the modules tracked. + state["graph_module_names"] = [ + importlib.util.find_spec(m.__name__).name for m in state["graph_modules"] + ] + del state["graph_modules"] # remove from state + return state + + def __setstate__(self, state): + # Restore instance attributes + self.__dict__.update(state) + # Reinitialize the unpicklable entries + # assumption is that the modules are importable in the new process + self.graph_modules = [importlib.import_module(n) for n in state["graph_module_names"]] + @staticmethod def normalize_adapter_input( adapter: Optional[ diff --git a/tests/resources/test_driver_serde_mapper.py b/tests/resources/test_driver_serde_mapper.py new file mode 100644 index 000000000..4ddb9f392 --- /dev/null +++ b/tests/resources/test_driver_serde_mapper.py @@ -0,0 +1,33 @@ +from typing import Any + +from hamilton.htypes import Collect, Parallelizable + + +def mapper( + drivers: list, + inputs: list, + final_vars: list = None, +) -> Parallelizable[dict]: + if final_vars is None: + final_vars = [] + for dr, input_ in zip(drivers, inputs): + yield { + "dr": dr, + "final_vars": final_vars or dr.list_available_variables(), + "input": input_, + } + + +def inside(mapper: dict) -> dict: + _dr = mapper["dr"] + _inputs = mapper["input"] + _final_var = mapper["final_vars"] + return _dr.execute(final_vars=_final_var, inputs=_inputs) + + +def passthrough(inside: dict) -> dict: + return inside + + +def reducer(passthrough: Collect[dict]) -> Any: + return passthrough diff --git a/tests/resources/test_driver_serde_worker.py b/tests/resources/test_driver_serde_worker.py new file mode 100644 index 000000000..9b4b421a1 --- /dev/null +++ b/tests/resources/test_driver_serde_worker.py @@ -0,0 +1,2 @@ +def double(a: int) -> int: + return a * 2 diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py index 54443fd1c..0504cafe1 100644 --- a/tests/test_hamilton_driver.py +++ b/tests/test_hamilton_driver.py @@ -19,6 +19,8 @@ import tests.resources.dynamic_parallelism.parallel_linear_basic import tests.resources.tagging import tests.resources.test_default_args +import tests.resources.test_driver_serde_mapper +import tests.resources.test_driver_serde_worker import tests.resources.test_for_materialization import tests.resources.very_simple_dag @@ -665,3 +667,29 @@ def func_to_test(a: int) -> int: assert v.tags == n.tags assert v.documentation == n.documentation == "This is a doctstring" assert v.originating_functions == n.originating_functions + + +def test_driver_setstate_getstate(): + """This is an integration test testing serializability of the hamilton driver.""" + from hamilton.execution import executors + + drivers = [] + inputs = [] + for i in range(4): + dr = Builder().with_modules(tests.resources.test_driver_serde_worker).build() + drivers.append(dr) + inputs.append({"a": i}) + + dr = ( + Builder() + .with_modules(tests.resources.test_driver_serde_mapper) + .enable_dynamic_execution(allow_experimental_mode=True) + # .with_local_executor(executors.SynchronousLocalTaskExecutor()) + .with_remote_executor(executors.MultiProcessingExecutor(8)) + .build() + ) + r = dr.execute( + final_vars=["reducer"], + inputs={"drivers": drivers, "inputs": inputs, "final_vars": ["double"]}, + ) + assert r == {"reducer": [{"double": 0}, {"double": 2}, {"double": 4}, {"double": 6}]} From d8ec5c72347d5ab2e3416006d68520a67fc3bbc3 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Mon, 19 Aug 2024 15:17:05 -0700 Subject: [PATCH 2/2] PR feedback --- hamilton/driver.py | 14 ++++++++++++-- tests/test_hamilton_driver.py | 3 +-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/hamilton/driver.py b/hamilton/driver.py index 436350641..a2ba558d5 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -263,21 +263,31 @@ class Driver: """ def __getstate__(self): + """Used for serialization.""" # Copy the object's state from self.__dict__ state = self.__dict__.copy() # Remove the unpicklable entries -- right now it's the modules tracked. - state["graph_module_names"] = [ + state["__graph_module_names"] = [ importlib.util.find_spec(m.__name__).name for m in state["graph_modules"] ] del state["graph_modules"] # remove from state return state def __setstate__(self, state): + """Used for deserialization.""" # Restore instance attributes self.__dict__.update(state) # Reinitialize the unpicklable entries # assumption is that the modules are importable in the new process - self.graph_modules = [importlib.import_module(n) for n in state["graph_module_names"]] + self.graph_modules = [] + for n in state["__graph_module_names"]: + try: + g_module = importlib.import_module(n) + except ImportError: + logger.error(f"Could not import module {n}") + continue + else: + self.graph_modules.append(g_module) @staticmethod def normalize_adapter_input( diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py index 0504cafe1..23fe8ac71 100644 --- a/tests/test_hamilton_driver.py +++ b/tests/test_hamilton_driver.py @@ -684,8 +684,7 @@ def test_driver_setstate_getstate(): Builder() .with_modules(tests.resources.test_driver_serde_mapper) .enable_dynamic_execution(allow_experimental_mode=True) - # .with_local_executor(executors.SynchronousLocalTaskExecutor()) - .with_remote_executor(executors.MultiProcessingExecutor(8)) + .with_remote_executor(executors.MultiProcessingExecutor(4)) .build() ) r = dr.execute(