Skip to content

Commit

Permalink
Merge pull request #2001 from fishtown-analytics/fix/snapshot-check-a…
Browse files Browse the repository at this point in the history
…ll-added-column

Fix snapshot check all with an added column (#1797)
  • Loading branch information
beckjake authored Dec 13, 2019
2 parents a702f58 + ab4925f commit b151e2a
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,32 @@
) %}
{% endmacro %}


{% macro snapshot_check_all_get_existing_columns(node, target_exists) -%}
    {#-
        Decide which columns a check_cols='all' snapshot should compare.

        Arguments:
            node           the snapshot node; its 'injected_sql', database,
                           schema and alias/name entries are read here.
            target_exists  truthy once the snapshot table has been built.

        Returns a two-item list [column_added, columns]:
            - target missing: [false, <every column the snapshot query yields>]
            - target present: column_added is true when the query yields a
              column that 'select * from <target>' does not, and columns holds
              only the query columns found in both.
    -#}
    {%- set query_columns = get_columns_in_query(node['injected_sql']) -%}

    {%- if not target_exists -%}
        {#- first build: nothing to reconcile against, use the query's columns as-is -#}
        {{ return([false, query_columns]) }}
    {%- endif -%}

    {#- the target already exists, so account for any schema changes -#}
    {%- set identifier = node.get('alias', node.get('name')) -%}
    {%- set target_relation = adapter.get_relation(
            database=node.database,
            schema=node.schema,
            identifier=identifier
        ) -%}
    {%- set existing_cols = get_columns_in_query('select * from ' ~ target_relation) -%}

    {#- assignments made inside a for-loop are scoped to the loop, so the flag lives on a namespace -#}
    {%- set state = namespace(column_added=false) -%}
    {%- set shared_cols = [] -%}
    {%- for column in query_columns -%}
        {%- if column not in existing_cols -%}
            {%- set state.column_added = true -%}
        {%- else -%}
            {%- do shared_cols.append(column) -%}
        {%- endif -%}
    {%- endfor -%}

    {{ return([state.column_added, shared_cols]) }}
{%- endmacro %}


{% macro snapshot_check_strategy(node, snapshotted_rel, current_rel, config, target_exists) %}
{% set check_cols_config = config['check_cols'] %}
{% set primary_key = config['unique_key'] %}
Expand All @@ -107,24 +133,29 @@
{%- endif %}
{% set updated_at = snapshot_string_as_time(now) %}

{% set column_added = false %}

{% if check_cols_config == 'all' %}
{% set check_cols = get_columns_in_query(node['injected_sql']) %}
{% set column_added, check_cols = snapshot_check_all_get_existing_columns(node, target_exists) %}
{% elif check_cols_config is iterable and (check_cols_config | length) > 0 %}
{% set check_cols = check_cols_config %}
{% else %}
{% do exceptions.raise_compiler_error("Invalid value for 'check_cols': " ~ check_cols_config) %}
{% endif %}

{% set row_changed_expr -%}
(
{% for col in check_cols %}
{{ snapshotted_rel }}.{{ col }} != {{ current_rel }}.{{ col }}
or
({{ snapshotted_rel }}.{{ col }} is null) != ({{ current_rel }}.{{ col }} is null)
{%- if not loop.last %} or {% endif %}

{% endfor %}
)
{%- set row_changed_expr -%}
(
{%- if column_added -%}
TRUE
{%- else -%}
{%- for col in check_cols -%}
{{ snapshotted_rel }}.{{ col }} != {{ current_rel }}.{{ col }}
or
({{ snapshotted_rel }}.{{ col }} is null) != ({{ current_rel }}.{{ col }} is null)
{%- if not loop.last %} or {% endif -%}
{%- endfor -%}
{%- endif -%}
)
{%- endset %}

{% set scd_id_expr = snapshot_hash_arguments([primary_key, updated_at]) %}
Expand Down
4 changes: 4 additions & 0 deletions test/integration/004_simple_snapshot_test/data/seed.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
id,first_name
1,Judith
2,Arthur
3,Rachel
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
id,first_name,last_name
1,Judith,Kennedy
2,Arthur,Kelly
3,Rachel,Moreno
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{% snapshot my_snapshot %}
    {#- 'check' strategy over all columns, keyed on id; the source seed is picked by the seed_name var (default 'seed') -#}
    {{
        config(
            strategy='check',
            check_cols='all',
            unique_key='id',
            target_database=database,
            target_schema=schema
        )
    }}
    select * from {{ ref(var('seed_name', 'seed')) }}
{% endsnapshot %}
85 changes: 81 additions & 4 deletions test/integration/004_simple_snapshot_test/test_simple_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,79 @@ def test__redshift__simple_snapshot(self):
self.assert_expected()


class TestSimpleColumnSnapshotFiles(DBTIntegrationTest):
    """Snapshot integration test for ``my_snapshot`` when its seed gains a column.

    The snapshot is run once against the default seed and once more with the
    ``seed_name`` var set to ``seed_newcol``; the shared scenario lives in
    ``_run_snapshot_test`` and is executed once per adapter profile.
    """

    @property
    def schema(self):
        """Schema name for this test."""
        return "simple_snapshot_004"

    @property
    def models(self):
        """Models directory for this test project."""
        return "models-checkall"

    @property
    def project_config(self):
        """Project configuration: seed, macro and snapshot paths plus seed options."""
        seed_settings = {
            'quote_columns': False,
        }
        return {
            'data-paths': ['data'],
            'macro-paths': ['custom-snapshot-macros', 'macros'],
            'snapshot-paths': ['test-snapshots-checkall'],
            'seeds': seed_settings,
        }

    def _fetch_snapshot_rows(self, sql_template, database):
        """Fill ``sql_template`` with the database and unique schema, then fetch all rows."""
        statement = sql_template.format(database, self.unique_schema())
        return self.run_sql(statement, fetch='all')

    def _run_snapshot_test(self):
        """Snapshot the seed, snapshot again with the wider seed, and check the table.

        First pass: 3 rows of 6 values each. Second pass (``seed_newcol``):
        3 rows where ``last_name`` is set and 3 where it is NULL, every row
        now carrying 7 values.
        """
        self.run_dbt(['seed'])
        self.run_dbt(['snapshot'])

        database = self.default_database
        if self.adapter_type == 'bigquery':
            database = self.adapter.quote(database)

        initial_rows = self._fetch_snapshot_rows(
            'select * from {}.{}.my_snapshot', database
        )
        self.assertEqual(len(initial_rows), 3)
        for row in initial_rows:
            self.assertEqual(len(row), 6)

        # Re-run the snapshot against the seed that has the extra column.
        self.run_dbt(['snapshot', '--vars', '{seed_name: seed_newcol}'])

        populated_rows = self._fetch_snapshot_rows(
            'select * from {}.{}.my_snapshot where last_name is not NULL',
            database,
        )
        self.assertEqual(len(populated_rows), 3)
        for row in populated_rows:
            # one more value per row now, and the last one must be populated
            self.assertEqual(len(row), 7)
            self.assertIsNotNone(row[-1])

        unpopulated_rows = self._fetch_snapshot_rows(
            'select * from {}.{}.my_snapshot where last_name is NULL',
            database,
        )
        self.assertEqual(len(unpopulated_rows), 3)
        for row in unpopulated_rows:
            # rows with a NULL last_name also expose the new column
            self.assertEqual(len(row), 7)

    @use_profile('postgres')
    def test_postgres_renamed_source(self):
        self._run_snapshot_test()

    @use_profile('snowflake')
    def test_snowflake_renamed_source(self):
        self._run_snapshot_test()

    @use_profile('redshift')
    def test_redshift_renamed_source(self):
        self._run_snapshot_test()

    @use_profile('bigquery')
    def test_bigquery_renamed_source(self):
        self._run_snapshot_test()


class TestCustomSnapshotFiles(BaseSimpleSnapshotTest):
@property
def project_config(self):
Expand Down Expand Up @@ -519,10 +592,8 @@ def test__bigquery__snapshot_with_new_field(self):
# This adds new fields to the source table, and updates the expected snapshot output accordingly
self.run_sql_file("add_column_to_source_bq.sql")

# this should fail because `check="all"` will try to compare the nested field
self.run_dbt(['snapshot'], expect_pass=False)

self.run_dbt(["snapshot", '--select', 'snapshot_actual'])
# check_cols='all' will replace the changed field
self.run_dbt(['snapshot'])

# A more thorough test would assert that snapshotted == expected, but BigQuery does not support the
# "EXCEPT DISTINCT" operator on nested fields! Instead, just check that schemas are congruent.
Expand All @@ -537,9 +608,15 @@ def test__bigquery__snapshot_with_new_field(self):
schema=self.unique_schema(),
table='snapshot_actual'
)
snapshotted_all_cols = self.get_table_columns(
database=self.default_database,
schema=self.unique_schema(),
table='snapshot_checkall'
)

self.assertTrue(len(expected_cols) > 0, "source table does not exist -- bad test")
self.assertEqual(len(expected_cols), len(snapshotted_cols), "actual and expected column lengths are different")
self.assertEqual(len(expected_cols), len(snapshotted_all_cols))

for (expected_col, actual_col) in zip(expected_cols, snapshotted_cols):
expected_name, expected_type, _ = expected_col
Expand Down
2 changes: 1 addition & 1 deletion test/integration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def run_sql_common(self, sql, fetch, conn):
else:
return
except BaseException as e:
if conn.handle and not conn.handle.is_closed():
if conn.handle and not conn.handle.closed:
conn.handle.rollback()
print(sql)
print(e)
Expand Down

0 comments on commit b151e2a

Please sign in to comment.