diff --git a/ui/sdk/src/hamilton_sdk/tracking/pyspark_stats.py b/ui/sdk/src/hamilton_sdk/tracking/pyspark_stats.py index 947b42525..32baaa406 100644 --- a/ui/sdk/src/hamilton_sdk/tracking/pyspark_stats.py +++ b/ui/sdk/src/hamilton_sdk/tracking/pyspark_stats.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional +import functools +from typing import Any, Dict, List, Optional import pyspark.sql as ps from hamilton_sdk.tracking import data_observation @@ -43,6 +44,8 @@ } +# quick cache to ensure we don't compute twice +@functools.lru_cache(maxsize=128) def _introspect(df: ps.DataFrame) -> Dict[str, Any]: """Introspect a PySpark dataframe and return a dictionary of statistics. @@ -105,6 +108,33 @@ def compute_schema_psdf( return None +@data_observation.compute_additional_results.register +def compute_additional_psdf( + result: ps.DataFrame, node_name: str, node_tags: dict +) -> List[ObservationType]: + o_value = _introspect(result) + return [ + { + "observability_type": "primitive", + "observability_value": { + "type": str(str), + "value": o_value["cost_explain"], + }, + "observability_schema_version": "0.0.1", + "name": "Cost Explain", + }, + { + "observability_type": "primitive", + "observability_value": { + "type": str(str), + "value": o_value["extended_explain"], + }, + "observability_schema_version": "0.0.1", + "name": "Extended Explain", + }, + ] + + if __name__ == "__main__": import numpy as np import pandas as pd