Skip to content

Commit

Permalink
Merge pull request #1402 from fishtown-analytics/fix/quote-databases-…
Browse files Browse the repository at this point in the history
…properly

Quote databases properly (#1396)
  • Loading branch information
beckjake authored Apr 30, 2019
2 parents ad1fcbe + 2834f2d commit 5a3e3ba
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 12 deletions.
6 changes: 4 additions & 2 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,9 @@ def list_relations(self, database, schema, model_name=None):
information_schema = self.Relation.create(
database=database,
schema=schema,
model_name='').information_schema()
model_name='',
quote_policy=self.config.quoting
).information_schema()

# we can't build the relations cache because we don't have a
# manifest so we can't run any operations.
Expand All @@ -581,7 +583,7 @@ def _make_match_kwargs(self, database, schema, identifier):
if schema is not None and quoting['schema'] is False:
schema = schema.lower()

if database is not None and quoting['schema'] is False:
if database is not None and quoting['database'] is False:
database = database.lower()

return filter_null_values({
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def list_relations_without_caching(self, information_schema, schema,

relations = []
quote_policy = {
'database': True,
'schema': True,
'identifier': True
}
Expand Down Expand Up @@ -233,7 +234,8 @@ def list_schemas(self, database, model_name=None):

def check_schema_exists(self, database, schema, model_name=None):
information_schema = self.Relation.create(
database=database, schema=schema
database=database, schema=schema,
quote_policy=self.config.quoting
).information_schema()

kwargs = {'information_schema': information_schema, 'schema': schema}
Expand Down
5 changes: 4 additions & 1 deletion plugins/postgres/dbt/adapters/postgres/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def date_function(cls):

@available_raw
def verify_database(self, database):
database = database.strip('"')
if database.startswith('"'):
database = database.strip('"')
else:
database = database.lower()
expected = self.config.credentials.database
if database != expected:
raise dbt.exceptions.NotImplementedException(
Expand Down
2 changes: 2 additions & 0 deletions test.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ SNOWFLAKE_TEST_ACCOUNT=
SNOWFLAKE_TEST_USER=
SNOWFLAKE_TEST_PASSWORD=
SNOWFLAKE_TEST_DATABASE=
SNOWFLAKE_TEST_ALT_DATABASE=
SNOWFLAKE_TEST_QUOTED_DATABASE=
SNOWFLAKE_TEST_WAREHOUSE=

BIGQUERY_TYPE=
Expand Down
28 changes: 28 additions & 0 deletions test/integration/001_simple_copy_test/test_simple_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def models(self):


class TestSimpleCopy(BaseTestSimpleCopy):

@use_profile("postgres")
def test__postgres__simple_copy(self):
self.use_default_project({"data-paths": [self.dir("seed-initial")]})
Expand Down Expand Up @@ -83,6 +84,12 @@ def test__snowflake__simple_copy(self):

self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"])

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-update")],
})
self.run_dbt(['test'])

@use_profile("snowflake")
def test__snowflake__simple_copy__quoting_off(self):
self.use_default_project({
Expand All @@ -108,6 +115,13 @@ def test__snowflake__simple_copy__quoting_off(self):

self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"])

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-update")],
"quoting": {"identifier": False},
})
self.run_dbt(['test'])

@use_profile("snowflake")
def test__snowflake__seed__quoting_switch(self):
self.use_default_project({
Expand All @@ -124,6 +138,12 @@ def test__snowflake__seed__quoting_switch(self):
})
results = self.run_dbt(["seed"], expect_pass=False)

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-initial")],
})
self.run_dbt(['test'])

@use_profile("bigquery")
def test__bigquery__simple_copy(self):
self.use_default_project({"data-paths": [self.dir("seed-initial")]})
Expand Down Expand Up @@ -181,6 +201,8 @@ def test__snowflake__simple_copy__quoting_on(self):

self.assertManyTablesEqual(["seed", "view_model", "incremental", "materialized"])

# can't run the test as this one's identifiers will be the wrong case


class BaseLowercasedSchemaTest(BaseTestSimpleCopy):
def unique_schema(self):
Expand Down Expand Up @@ -210,6 +232,12 @@ def test__snowflake__simple_copy(self):

self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"])

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-update")],
})
self.run_dbt(['test'])


class TestSnowflakeSimpleLowercasedSchemaQuoted(BaseLowercasedSchemaTest):
@property
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{%- set tgt = ref('seed') -%}
{%- set got = adapter.get_relation(database=tgt.database, schema=tgt.schema, identifier=tgt.identifier) | string -%}
{% set replaced = got.replace('"', '-') %}
{% set expected = "-" + tgt.database.upper() + '-.-' + tgt.schema.upper() + '-.-' + tgt.identifier.upper() + '-' %}

with cte as (
select '{{ replaced }}' as name
)
select * from cte where name not like '{{ expected }}'
6 changes: 6 additions & 0 deletions test/integration/040_override_database_test/models/view_1.sql
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
{#
We are running against a database that must be quoted.
These calls ensure that we trigger an error if we're failing to quote at parse-time
#}
{% do adapter.already_exists(this.schema, this.table) %}
{% do adapter.get_relation(this.database, this.schema, this.table) %}
select * from {{ ref('seed') }}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{{
config(database=var('alternate_db'))
}}

select * from {{ ref('seed') }}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from nose.plugins.attrib import attr
from test.integration.base import DBTIntegrationTest

import os


class BaseOverrideDatabase(DBTIntegrationTest):
setup_alternate_db = True
Expand All @@ -12,6 +14,45 @@ def schema(self):
def models(self):
return "test/integration/040_override_database_test/models"

@property
def alternative_database(self):
if self.adapter_type == 'snowflake':
return os.getenv('SNOWFLAKE_TEST_DATABASE')
else:
return super(BaseOverrideDatabase, self).alternative_database

def snowflake_profile(self):
return {
'config': {
'send_anonymous_usage_stats': False
},
'test': {
'outputs': {
'default2': {
'type': 'snowflake',
'threads': 4,
'account': os.getenv('SNOWFLAKE_TEST_ACCOUNT'),
'user': os.getenv('SNOWFLAKE_TEST_USER'),
'password': os.getenv('SNOWFLAKE_TEST_PASSWORD'),
'database': os.getenv('SNOWFLAKE_TEST_QUOTED_DATABASE'),
'schema': self.unique_schema(),
'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'),
},
'noaccess': {
'type': 'snowflake',
'threads': 4,
'account': os.getenv('SNOWFLAKE_TEST_ACCOUNT'),
'user': 'noaccess',
'password': 'password',
'database': os.getenv('SNOWFLAKE_TEST_DATABASE'),
'schema': self.unique_schema(),
'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'),
}
},
'target': 'default2'
}
}

@property
def project_config(self):
return {
Expand All @@ -20,9 +61,15 @@ def project_config(self):
'vars': {
'alternate_db': self.alternative_database,
},
},
'quoting': {
'database': True,
}
}

def run_dbt_notstrict(self, args):
return self.run_dbt(args, strict=False)


class TestModelOverride(BaseOverrideDatabase):
def run_database_override(self):
Expand All @@ -31,9 +78,9 @@ def run_database_override(self):
else:
func = lambda x: x

self.run_dbt(['seed'])
self.run_dbt_notstrict(['seed'])

self.assertEqual(len(self.run_dbt(['run'])), 4)
self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4)
self.assertManyRelationsEqual([
(func('seed'), self.unique_schema(), self.default_database),
(func('view_2'), self.unique_schema(), self.alternative_database),
Expand Down Expand Up @@ -71,9 +118,9 @@ def run_database_override(self):
},
}
})
self.run_dbt(['seed'])
self.run_dbt_notstrict(['seed'])

self.assertEqual(len(self.run_dbt(['run'])), 4)
self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4)
self.assertManyRelationsEqual([
(func('seed'), self.unique_schema(), self.default_database),
(func('view_2'), self.unique_schema(), self.alternative_database),
Expand Down Expand Up @@ -101,9 +148,9 @@ def run_database_override(self):
self.use_default_project({
'seeds': {'database': self.alternative_database}
})
self.run_dbt(['seed'])
self.run_dbt_notstrict(['seed'])

self.assertEqual(len(self.run_dbt(['run'])), 4)
self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4)
self.assertManyRelationsEqual([
(func('seed'), self.unique_schema(), self.alternative_database),
(func('view_2'), self.unique_schema(), self.alternative_database),
Expand Down
4 changes: 3 additions & 1 deletion test/integration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,9 @@ def get_many_table_columns(self, tables, schema, database=None):
and ({table_filter})
order by column_name asc"""

db_string = '' if database is None else database + '.'
db_string = ''
if database:
db_string = self.quote_as_configured(database, 'database') + '.'

table_filters_s = " OR ".join(
self._ilike('table_name', table.replace('"', ''))
Expand Down

0 comments on commit 5a3e3ba

Please sign in to comment.