
Commit

feat(sqlparser): support more update syntaxes + fix bug with subqueries (#9105)
hsheth2 authored Oct 30, 2023
1 parent 0bd2d9a commit ce0f36b
Showing 6 changed files with 247 additions and 6 deletions.
57 changes: 52 additions & 5 deletions metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
@@ -12,8 +12,8 @@
import sqlglot.errors
import sqlglot.lineage
import sqlglot.optimizer.annotate_types
+import sqlglot.optimizer.optimizer
import sqlglot.optimizer.qualify
import sqlglot.optimizer.qualify_columns
from pydantic import BaseModel
from typing_extensions import TypedDict

@@ -48,6 +48,19 @@
SQL_PARSE_RESULT_CACHE_SIZE = 1000


+RULES_BEFORE_TYPE_ANNOTATION: tuple = tuple(
+    filter(
+        # Skip pushdown_predicates because it sometimes throws exceptions, and we
+        # don't actually need it for anything.
+        lambda func: func.__name__ not in {"pushdown_predicates"},
+        itertools.takewhile(
+            lambda func: func != sqlglot.optimizer.annotate_types.annotate_types,
+            sqlglot.optimizer.optimizer.RULES,
+        ),
+    )
+)
+
+
class GraphQLSchemaField(TypedDict):
    fieldPath: str
    nativeDataType: str
@@ -289,6 +302,10 @@ def _table_level_lineage(
    )
    # TODO: If a CTAS has "LIMIT 0", it's not really lineage, just copying the schema.

+    # Update statements implicitly read from the table being updated, so add those back in.
+    if isinstance(statement, sqlglot.exp.Update):
+        tables = tables | modified
+
    return tables, modified
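
To see the new implicit-read branch in isolation, here is a minimal standalone sketch (hypothetical table name; plain string sets stand in for the parser's internal _TableName sets):

import sqlglot
import sqlglot.exp

# A self-referential UPDATE: it writes `orders`, and the SET expression
# also reads from it.
statement = sqlglot.parse_one(
    "UPDATE orders SET orderkey = orderkey + 1", dialect="snowflake"
)

tables: set = set()               # tables read via FROM/subqueries (none here)
modified = {statement.this.name}  # the table being written: {'orders'}

# Mirror of the new branch above: an UPDATE implicitly reads the table it writes.
if isinstance(statement, sqlglot.exp.Update):
    tables = tables | modified

print(tables, modified)  # {'orders'} {'orders'}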


@@ -568,17 +585,20 @@ def _schema_aware_fuzzy_column_resolve(
        # - the select instead of the full outer statement
        # - schema info
        # - column qualification enabled
+        # - running the full pre-type annotation optimizer

        # logger.debug("Schema: %s", sqlglot_db_schema.mapping)
-        statement = sqlglot.optimizer.qualify.qualify(
+        statement = sqlglot.optimizer.optimizer.optimize(
            statement,
            dialect=dialect,
            schema=sqlglot_db_schema,
+            qualify_columns=True,
            validate_qualify_columns=False,
            identify=True,
            # sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table".
            catalog=default_db,
            db=default_schema,
+            rules=RULES_BEFORE_TYPE_ANNOTATION,
        )
    except (sqlglot.errors.OptimizeError, ValueError) as e:
        raise SqlUnderstandingError(
@@ -748,6 +768,7 @@ def _extract_select_from_create(
_UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT: Set[str] = set(
    sqlglot.exp.Update.arg_types.keys()
) - set(sqlglot.exp.Select.arg_types.keys())
+_UPDATE_FROM_TABLE_ARGS_TO_MOVE = {"joins", "laterals", "pivot"}
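
Why these args need moving at all: in sqlglot's AST (as this commit's helper assumes for the pinned sqlglot version), a JOIN in an UPDATE ... FROM hangs off the Table node inside the FROM clause rather than off the Update itself. A small sketch with hypothetical tables:

import sqlglot

stmt = sqlglot.parse_one(
    "UPDATE t1 SET a = t2.a FROM t2 JOIN t3 ON t2.id = t3.id",
    dialect="postgres",
)

# The join lives on Update -> From -> Table, so a naive Update -> Select
# conversion would silently drop it; the helper below pops these args off
# the table and re-attaches them to the generated Select.
from_table = stmt.args["from"].this
print(type(from_table).__name__)     # Table
print(from_table.args.get("joins"))  # [Join(...)]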


def _extract_select_from_update(
@@ -774,17 +795,43 @@ def _extract_select_from_update(
            # they'll get caught later.
            new_expressions.append(expr)

-    return sqlglot.exp.Select(
+    # Special translation for the `from` clause.
+    extra_args = {}
+    original_from = statement.args.get("from")
+    if original_from and isinstance(original_from.this, sqlglot.exp.Table):
+        # Move joins, laterals, and pivots from the Update->From->Table->field
+        # to the top-level Select->field.
+
+        for k in _UPDATE_FROM_TABLE_ARGS_TO_MOVE:
+            if k in original_from.this.args:
+                # Mutate the from table clause in-place.
+                extra_args[k] = original_from.this.args.pop(k)
+
+    select_statement = sqlglot.exp.Select(
        **{
            **{
                k: v
                for k, v in statement.args.items()
                if k not in _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
            },
+            **extra_args,
            "expressions": new_expressions,
        }
    )
+
+    # Update statements always implicitly have the updated table in context.
+    # TODO: Retain table name alias.
+    if select_statement.args.get("from"):
+        # select_statement = sqlglot.parse_one(select_statement.sql(dialect=dialect))
+
+        select_statement = select_statement.join(
+            statement.this, append=True, join_kind="cross"
+        )
+    else:
+        select_statement = select_statement.from_(statement.this)
+
+    return select_statement


def _is_create_table_ddl(statement: sqlglot.exp.Expression) -> bool:
    return isinstance(statement, sqlglot.exp.Create) and isinstance(
@@ -955,7 +1002,7 @@ def _sqlglot_lineage_inner(
    # Fetch schema info for the relevant tables.
    table_name_urn_mapping: Dict[_TableName, str] = {}
    table_name_schema_mapping: Dict[_TableName, SchemaInfo] = {}
-    for table in itertools.chain(tables, modified):
+    for table in tables | modified:
        # For select statements, qualification will be a no-op. For other statements, this
        # is where the qualification actually happens.
        qualified_table = table.qualified(
@@ -971,7 +1018,7 @@
        # Also include the original, non-qualified table name in the urn mapping.
        table_name_urn_mapping[table] = urn

-    total_tables_discovered = len(tables) + len(modified)
+    total_tables_discovered = len(tables | modified)
    total_schemas_resolved = len(table_name_schema_mapping)
    debug_info = SqlParsingDebugInfo(
        confidence=0.9 if total_tables_discovered == total_schemas_resolved
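
Taken together, the effect of _extract_select_from_update can be sketched in SQL terms (hypothetical tables and columns; the real helper rewrites the AST directly rather than going through SQL text):

import sqlglot

# An UPDATE ... FROM in the style of the snowflake tests below.
update_sql = """
UPDATE my_table
SET col1 = t1.col1
FROM table1 AS t1
JOIN table2 AS t2 ON t1.id = t2.id
WHERE my_table.id = t1.id
"""

# Roughly the SELECT built for column lineage: the SET expression becomes a
# projection, the FROM clause keeps its (hoisted) join, and the updated table
# is cross-joined back in so its own columns stay resolvable.
select_sql = """
SELECT t1.col1 AS col1
FROM table1 AS t1
JOIN table2 AS t2 ON t1.id = t2.id
CROSS JOIN my_table
WHERE my_table.id = t1.id
"""

for sql in (update_sql, select_sql):
    print(sqlglot.parse_one(sql, dialect="snowflake").sql(pretty=True))
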
64 changes: 64 additions & 0 deletions metadata-ingestion/tests/unit/sql_parsing/goldens/test_postgres_select_subquery.json
@@ -0,0 +1,64 @@
{
  "query_type": "SELECT",
  "in_tables": [
    "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)",
    "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table2,PROD)"
  ],
  "out_tables": [],
  "column_lineage": [
    {
      "downstream": {
        "table": null,
        "column": "a",
        "column_type": {
          "type": {
            "com.linkedin.pegasus2avro.schema.NumberType": {}
          }
        },
        "native_column_type": "INT"
      },
      "upstreams": [
        {
          "table": "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)",
          "column": "a"
        }
      ]
    },
    {
      "downstream": {
        "table": null,
        "column": "b",
        "column_type": {
          "type": {
            "com.linkedin.pegasus2avro.schema.NumberType": {}
          }
        },
        "native_column_type": "INT"
      },
      "upstreams": [
        {
          "table": "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)",
          "column": "b"
        }
      ]
    },
    {
      "downstream": {
        "table": null,
        "column": "c",
        "column_type": {
          "type": {
            "com.linkedin.pegasus2avro.schema.ArrayType": {}
          }
        },
        "native_column_type": "INT[]"
      },
      "upstreams": [
        {
          "table": "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table2,PROD)",
          "column": "c"
        }
      ]
    }
  ]
}
1 change: 1 addition & 0 deletions metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_from_table.json
@@ -1,6 +1,7 @@
{
  "query_type": "UPDATE",
  "in_tables": [
+    "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)",
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table1,PROD)",
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table2,PROD)"
  ],
4 changes: 3 additions & 1 deletion
@@ -1,6 +1,8 @@
{
  "query_type": "UPDATE",
-  "in_tables": [],
+  "in_tables": [
+    "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)"
+  ],
  "out_tables": [
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)"
  ],
29 changes: 29 additions & 0 deletions metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_self.json
@@ -0,0 +1,29 @@
{
  "query_type": "UPDATE",
  "in_tables": [
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)"
  ],
  "out_tables": [
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)"
  ],
  "column_lineage": [
    {
      "downstream": {
        "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
        "column": "orderkey",
        "column_type": {
          "type": {
            "com.linkedin.pegasus2avro.schema.NumberType": {}
          }
        },
        "native_column_type": "DECIMAL"
      },
      "upstreams": [
        {
          "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)",
          "column": "orderkey"
        }
      ]
    }
  ]
}
98 changes: 98 additions & 0 deletions metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py
@@ -768,3 +768,101 @@ def test_snowflake_update_from_table():
        },
        expected_file=RESOURCE_DIR / "test_snowflake_update_from_table.json",
    )
+
+
+def test_snowflake_update_self():
+    assert_sql_result(
+        """
+UPDATE snowflake_sample_data.tpch_sf1.orders
+SET orderkey = orderkey + 1
+""",
+        dialect="snowflake",
+        schemas={
+            "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)": {
+                "orderkey": "NUMBER(38,0)",
+                "totalprice": "NUMBER(12,2)",
+            },
+        },
+        expected_file=RESOURCE_DIR / "test_snowflake_update_self.json",
+    )
+
+
+def test_postgres_select_subquery():
+    assert_sql_result(
+        """
+SELECT
+    a,
+    b,
+    (SELECT c FROM table2 WHERE table2.id = table1.id) as c
+FROM table1
+""",
+        dialect="postgres",
+        default_db="my_db",
+        default_schema="my_schema",
+        schemas={
+            "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)": {
+                "id": "INTEGER",
+                "a": "INTEGER",
+                "b": "INTEGER",
+            },
+            "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table2,PROD)": {
+                "id": "INTEGER",
+                "c": "INTEGER",
+            },
+        },
+        expected_file=RESOURCE_DIR / "test_postgres_select_subquery.json",
+    )
+
+
+@pytest.mark.skip(reason="We can't parse column-list syntax with sub-selects yet")
+def test_postgres_update_subselect():
+    assert_sql_result(
+        """
+UPDATE accounts SET sales_person_name =
+    (SELECT name FROM employees
+     WHERE employees.id = accounts.sales_person_id)
+""",
+        dialect="postgres",
+        default_db="my_db",
+        default_schema="my_schema",
+        schemas={
+            "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.accounts,PROD)": {
+                "id": "INTEGER",
+                "sales_person_id": "INTEGER",
+                "sales_person_name": "VARCHAR(16777216)",
+            },
+            "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.employees,PROD)": {
+                "id": "INTEGER",
+                "name": "VARCHAR(16777216)",
+            },
+        },
+        expected_file=RESOURCE_DIR / "test_postgres_update_subselect.json",
+    )
+
+
+@pytest.mark.skip(reason="We can't parse column-list syntax with sub-selects yet")
+def test_postgres_complex_update():
+    # Example query from the postgres docs:
+    # https://www.postgresql.org/docs/current/sql-update.html
+    assert_sql_result(
+        """
+UPDATE accounts SET (contact_first_name, contact_last_name) =
+    (SELECT first_name, last_name FROM employees
+     WHERE employees.id = accounts.sales_person);
+""",
+        dialect="postgres",
+        schemas={
+            "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.accounts,PROD)": {
+                "id": "INTEGER",
+                "contact_first_name": "VARCHAR(16777216)",
+                "contact_last_name": "VARCHAR(16777216)",
+                "sales_person": "INTEGER",
+            },
+            "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.employees,PROD)": {
+                "id": "INTEGER",
+                "first_name": "VARCHAR(16777216)",
+                "last_name": "VARCHAR(16777216)",
+            },
+        },
+        expected_file=RESOURCE_DIR / "test_postgres_complex_update.json",
+    )
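
For reference, the column-list syntax that both skips gate on; whether it parses cleanly depends on the sqlglot release, so this sketch just probes it rather than asserting an outcome:

import sqlglot
import sqlglot.errors

sql = """
UPDATE accounts SET (contact_first_name, contact_last_name) =
    (SELECT first_name, last_name FROM employees
     WHERE employees.id = accounts.sales_person)
"""

try:
    parsed = sqlglot.parse_one(sql, dialect="postgres")
    # Even if a newer release parses this, check the round-trip before
    # un-skipping the tests above.
    print(parsed.sql(dialect="postgres", pretty=True))
except sqlglot.errors.ParseError as exc:
    print("column-list UPDATE not parseable yet:", exc)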
