Skip to content

Commit

Permalink
fixed JSON export for 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
zilto authored and zilto committed Mar 4, 2024
1 parent 86736c0 commit 248f191
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions tests/test_driver_export.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json

from hamilton import driver
from hamilton import driver, graph_types

from tests.resources.dynamic_parallelism import no_parallel

# `version` attribute needs to be filled
EXPECTED_JSON = {
"nodes": [
{
Expand All @@ -18,7 +19,7 @@
" return sum_step_squared_plus_step_cubed\n"
),
"documentation": "",
"version": "27af6c2870dd4ca92335df4a9ff6c84da829d792a4c9f690f2f57193dbf3395a",
"version": None,
},
{
"name": "number_of_steps",
Expand All @@ -28,7 +29,7 @@
"optional_dependencies": [],
"source": "def number_of_steps() -> int:\n return 5\n",
"documentation": "",
"version": "48945782868c44bca1b159682d48bee3a1c1e9be73cf059c5087d0d3258fdc78",
"version": None,
},
{
"name": "step_cubed",
Expand All @@ -42,7 +43,7 @@
" return steps**3\n"
),
"documentation": "",
"version": "f82b894f3d8a8de9cf798a41e08c1f38c8396e86c569cb73b7d78b4420f9dd99",
"version": None,
},
{
"name": "step_squared",
Expand All @@ -53,7 +54,7 @@
"source": 'def step_squared(steps: int) -> int:\n print("squaring step {}".format(steps))\n '
" return steps**2\n",
"documentation": "",
"version": "371bff5e02e091314e0ef2c3ca3003a496a119f9164d3e1827fe64d5f8c4aaba",
"version": None,
},
{
"name": "step_squared_plus_step_cubed",
Expand All @@ -67,7 +68,7 @@
" return step_squared + step_cubed\n"
),
"documentation": "",
"version": "fe76eaaf43a6b79a706574ea6a8f00fe5b6439e3b8d7ceb879ab938b7722a1b3",
"version": None,
},
{
"name": "steps",
Expand All @@ -77,7 +78,7 @@
"optional_dependencies": [],
"source": "def steps(number_of_steps: int) -> int:\n return number_of_steps\n",
"documentation": "",
"version": "111ccfe2b5e8754e5ada1dae3ffc49d27b450a98e3be61f9308144aa5a444921",
"version": None,
},
{
"name": "sum_step_squared_plus_step_cubed",
Expand All @@ -91,13 +92,21 @@
" return step_squared_plus_step_cubed\n"
),
"documentation": "",
"version": "7831c9ed5f45e0b8ce3de34c3a5e2b4d4467b1ebc44cb55ea6ecae8f7dbf749d",
"version": None,
},
],
}


def test_export_execution():
dr = driver.Builder().with_modules(no_parallel).build()
graph = graph_types.HamiltonGraph.from_graph(dr.graph)

# can't hard code HamiltonNode.version because it depends on Python version
for n in graph.nodes:
for json_n in EXPECTED_JSON["nodes"]:
if json_n["name"] == n.name:
json_n["version"] = n.version

json_str = dr.export_execution(["final"])
assert json_str == json.dumps(EXPECTED_JSON)

0 comments on commit 248f191

Please sign in to comment.