Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/jeb snowflake source quoting #1338

Merged
merged 10 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 85 additions & 16 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import abc
import copy
import multiprocessing
import time

import agate
import pytz
Expand All @@ -13,11 +10,11 @@
import dbt.clients.agate_helper

from dbt.compat import abstractclassmethod, classmethod
from dbt.contracts.connection import Connection
from dbt.node_types import NodeType
from dbt.loader import GraphLoader
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.schema import Column
from dbt.utils import filter_null_values, translate_aliases
from dbt.utils import filter_null_values

from dbt.adapters.base.meta import AdapterMeta, available, available_raw, \
available_deprecated
Expand Down Expand Up @@ -94,6 +91,51 @@ def _utc(dt, source, field_name):
return dt.replace(tzinfo=pytz.UTC)


class SchemaSearchMap(dict):
    """Mapping of information_schema relations to the schemas to search there.

    Keys are the objects returned by ``relation.information_schema_only()``;
    each value is a set of lowercased schema names.
    """
    def add(self, relation):
        """Record ``relation``'s lowercased schema under its
        information_schema key, creating the set on first use.
        """
        info_schema = relation.information_schema_only()
        self.setdefault(info_schema, set()).add(relation.schema.lower())

    def search(self):
        """Yield an ``(information_schema, schema)`` pair for every schema
        name stored in the map.
        """
        for info_schema in self:
            for schema_name in self[info_schema]:
                yield info_schema, schema_name

    def schemas_searched(self):
        """Return the set of ``(database, schema)`` pairs in the map, where
        ``database`` is read from each key's ``database`` attribute.
        """
        return {
            (info_schema.database, schema_name)
            for info_schema, schema_name in self.search()
        }

    def flatten(self):
        """Return a new map of the same class whose entries all share one
        database name.

        The lowercased database names of all keys with a truthy ``database``
        are collected. More than one distinct name is reported through
        ``dbt.exceptions.raise_compiler_error``; exactly one name is used for
        every entry; with none, the database stays ``None``.
        """
        flattened = self.__class__()

        # distinct lowercased database names across all keys
        names = {key.database.lower() for key in self if key.database}

        database = None
        if len(names) == 1:
            (database,) = names
        elif len(names) > 1:
            dbt.exceptions.raise_compiler_error(str(names))

        for info_schema, schema_name in self.search():
            updated = info_schema.incorporate(
                path={'database': database, 'schema': schema_name},
                quote_policy={'database': False},
                include_policy={'database': False},
            )
            flattened.add(updated)

        return flattened


@six.add_metaclass(AdapterMeta)
class BaseAdapter(object):
"""The BaseAdapter provides an abstract base class for adapters.
Expand Down Expand Up @@ -237,24 +279,44 @@ def _relations_filter_table(cls, table, schemas):
"""
return table.where(_relations_filter_schemas(schemas))

def _get_cache_schemas(self, manifest, exec_only=False):
    """Build a SchemaSearchMap from the nodes of the manifest.

    Every node in ``manifest.nodes`` is turned into a relation with
    ``self.Relation.create_from(self.config, node)`` and added to the map.
    The map's keys are information_schema relations and its values are
    sets of lowercase schema names.

    Keys may be technically duplicates on the database side: for example
    all of '"foo"', 'foo', '"FOO"' and 'FOO' could coexist as databases,
    and their values could overlap.

    :param manifest: The manifest whose nodes are inspected.
    :param bool exec_only: If True, skip nodes whose resource_type is not
        in ``NodeType.executable()``.
    """
    search_map = SchemaSearchMap()
    for node in manifest.nodes.values():
        if not exec_only or node.resource_type in NodeType.executable():
            search_map.add(self.Relation.create_from(self.config, node))
    return search_map

def _relations_cache_for_schemas(self, manifest):
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
"""
# NOTE(review): the only return statement visible here is the bare
# `return` below; the "Returns an iterable" claim in the docstring is not
# reflected in the visible code - confirm and update the docstring.
#
# NOTE(review): this span is a rendered diff with indentation and +/-
# markers stripped. The statements that use `schemas` and the statements
# that use `info_schema_name_map` appear to be the before/after versions
# of the same logic - confirm against the real file before editing.
if not dbt.flags.USE_CACHE:
return

schemas = manifest.get_used_schemas()

relations = []
# add all relations
for db, schema in schemas:
info_schema_name_map = self._get_cache_schemas(manifest,
exec_only=True)
for db, schema in info_schema_name_map.search():
for relation in self.list_relations_without_caching(db, schema):
self.cache.add(relation)

# it's possible that there were no relations in some schemas. We want
# to insert the schemas we query into the cache's `.schemas` attribute
# so we can check it later
self.cache.update_schemas(schemas)
self.cache.update_schemas(info_schema_name_map.schemas_searched())

def set_relations_cache(self, manifest, clear=False):
"""Run a query that gets a populated cache of the relations in the
Expand Down Expand Up @@ -415,13 +477,14 @@ def expand_column_types(self, goal, current, model_name=None):
)

@abc.abstractmethod
def list_relations_without_caching(self, database, schema,
def list_relations_without_caching(self, information_schema, schema,
model_name=None):
"""List relations in the given schema, bypassing the cache.

This is used as the underlying behavior to fill the cache.

:param str database: The name of the database to list relations from.
:param Relation information_schema: The information schema to list
relations from.
:param str schema: The name of the schema to list relations from.
:param Optional[str] model_name: The name of the model to use for the
connection.
Expand Down Expand Up @@ -495,10 +558,15 @@ def list_relations(self, database, schema, model_name=None):
if self._schema_is_cached(database, schema, model_name):
return self.cache.get_relations(database, schema)

information_schema = self.Relation.create(
database=database,
schema=schema,
model_name='').information_schema()

# we can't build the relations cache because we don't have a
# manifest so we can't run any operations.
relations = self.list_relations_without_caching(
database, schema, model_name=model_name
information_schema, schema, model_name=model_name
)

logger.debug('with schema={}, model_name={}, relations={}'
Expand Down Expand Up @@ -802,10 +870,11 @@ def get_catalog(self, manifest):
"""Get the catalog for this manifest by running the get catalog macro.
Returns an agate.Table of catalog information.
"""
information_schemas = list(self._get_cache_schemas(manifest).keys())
# make it a list so macros can index into it.
context = {'databases': list(manifest.get_used_databases())}
kwargs = {'information_schemas': information_schemas}
table = self.execute_macro(GET_CATALOG_MACRO_NAME,
context_override=context,
kwargs=kwargs,
release=True)

results = self._catalog_filter_table(table, manifest)
Expand Down
53 changes: 46 additions & 7 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dbt.api import APIObject
from dbt.utils import filter_null_values
from dbt.node_types import NodeType

import dbt.exceptions

Expand Down Expand Up @@ -30,15 +31,15 @@ class BaseRelation(APIObject):
'database': True,
'schema': True,
'identifier': True
}
},
}

PATH_SCHEMA = {
'type': 'object',
'properties': {
'database': {'type': ['string', 'null']},
'schema': {'type': ['string', 'null']},
'identifier': {'type': 'string'},
'identifier': {'type': ['string', 'null']},
},
'required': ['database', 'schema', 'identifier'],
}
Expand Down Expand Up @@ -135,6 +136,36 @@ def include(self, database=None, schema=None, identifier=None):

return self.incorporate(include_policy=policy)

def information_schema(self, identifier=None):
    """Build the information_schema counterpart of this relation.

    The result is whatever ``self.incorporate`` returns when the schema
    path is set to ``'information_schema'`` and the identifier path and
    table_name are set to ``identifier``. The database is included only
    when ``self.database`` is not None and keeps this relation's database
    quoting; the schema is always included; the identifier is included
    only when one is given. Schema and identifier are never quoted.

    :param Optional[str] identifier: The information_schema table to
        point at, or None for the information_schema itself.
    """
    has_database = self.database is not None
    has_identifier = identifier is not None

    included = filter_null_values({
        'database': has_database,
        'schema': True,
        'identifier': has_identifier
    })
    quoted = filter_null_values({
        'database': self.quote_policy['database'],
        'schema': False,
        'identifier': False,
    })
    new_path = {
        'schema': 'information_schema',
        'identifier': identifier
    }

    return self.incorporate(
        quote_policy=quoted,
        include_policy=included,
        path=new_path,
        table_name=identifier)

def information_schema_only(self):
    """Return ``self.information_schema()`` called without an identifier."""
    schema_relation = self.information_schema()
    return schema_relation

def information_schema_table(self, identifier):
    """Return ``self.information_schema(identifier)`` for the given
    identifier.

    :param str identifier: The name passed through to information_schema.
    """
    table_relation = self.information_schema(identifier)
    return table_relation

def render(self, use_table_name=True):
parts = []

Expand Down Expand Up @@ -174,15 +205,16 @@ def quoted(self, identifier):

# Create a relation from a source definition, taking database, schema and
# identifier from the source's attributes of the same names.
#
# NOTE(review): this span is a rendered diff with indentation and +/-
# markers stripped. The call below carries `quote_policy=` twice: once as a
# literal all-True dict and once as the merged `quote_policy` variable.
# These appear to be the before/after versions of the same argument -
# confirm against the real file.
#
# NOTE(review): `quote_policy` is read from kwargs with .get(), which leaves
# it in kwargs, and kwargs is also splatted into cls.create() next to an
# explicit quote_policy=. If a caller passes quote_policy this looks like it
# would raise TypeError for a duplicate keyword argument - confirm.
@classmethod
def create_from_source(cls, source, **kwargs):
# merge order: class defaults, then the source's quoting, then any
# caller-supplied quote_policy
quote_policy = dbt.utils.deep_merge(
cls.DEFAULTS['quote_policy'],
source.quoting,
kwargs.get('quote_policy', {})
)
return cls.create(
database=source.database,
schema=source.schema,
identifier=source.identifier,
quote_policy={
'database': True,
'schema': True,
'identifier': True,
},
quote_policy=quote_policy,
**kwargs
)

Expand All @@ -202,6 +234,13 @@ def create_from_node(cls, config, node, table_name=None, quote_policy=None,
quote_policy=quote_policy,
**kwargs)

@classmethod
def create_from(cls, config, node, **kwargs):
    """Create a relation from either a source or a regular node.

    Nodes whose resource_type equals ``NodeType.Source`` are passed to
    ``create_from_source`` (which does not receive ``config``); all other
    nodes are passed to ``create_from_node``.

    :param config: The config forwarded to create_from_node.
    :param node: The node to build a relation for.
    :param kwargs: Extra keyword arguments forwarded to the chosen
        constructor.
    """
    is_source = node.resource_type == NodeType.Source
    if not is_source:
        return cls.create_from_node(config, node, **kwargs)
    return cls.create_from_source(node, **kwargs)

@classmethod
def create(cls, database=None, schema=None,
identifier=None, table_name=None,
Expand Down
17 changes: 9 additions & 8 deletions core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import abc
import time

import agate
import six

import dbt.clients.agate_helper
import dbt.exceptions
import dbt.flags
from dbt.adapters.base import BaseAdapter, available
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.compat import abstractclassmethod


LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
Expand Down Expand Up @@ -196,11 +191,12 @@ def drop_schema(self, database, schema, model_name=None):
kwargs=kwargs,
connection_name=model_name)

def list_relations_without_caching(self, database, schema,
def list_relations_without_caching(self, information_schema, schema,
model_name=None):
kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
LIST_RELATIONS_MACRO_NAME,
kwargs={'database': database, 'schema': schema},
kwargs=kwargs,
connection_name=model_name,
release=True
)
Expand Down Expand Up @@ -236,9 +232,14 @@ def list_schemas(self, database, model_name=None):
return [row[0] for row in results]

def check_schema_exists(self, database, schema, model_name=None):
# Run the macro named by CHECK_SCHEMA_EXISTS_MACRO_NAME on the connection
# named by model_name and report whether the first column of its first
# result row is greater than zero.
#
# NOTE(review): this span is a rendered diff with indentation and +/-
# markers stripped. The execute_macro call below carries `kwargs=` twice:
# once with a database/schema dict and once with the `kwargs` variable that
# holds information_schema/schema. These appear to be the before/after
# versions of the same argument - confirm against the real file.
information_schema = self.Relation.create(
database=database, schema=schema
).information_schema()

kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
CHECK_SCHEMA_EXISTS_MACRO_NAME,
kwargs={'database': database, 'schema': schema},
kwargs=kwargs,
connection_name=model_name
)
return results[0][0] > 0
5 changes: 5 additions & 0 deletions core/dbt/context/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def __init__(self, adapter):
def __getattr__(self, key):
return getattr(self.relation_type, key)

def create_from_source(self, *args, **kwargs):
    """Forward all arguments unchanged to
    ``self.relation_type.create_from_source`` and return its result.
    """
    # go straight to the relation type rather than through our own create,
    # so that the source quoting is left untouched
    build_from_source = self.relation_type.create_from_source
    return build_from_source(*args, **kwargs)

def create(self, *args, **kwargs):
kwargs['quote_policy'] = dbt.utils.merge(
self.quoting_config,
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
PARSED_MACRO_CONTRACT, PARSED_DOCUMENTATION_CONTRACT, \
PARSED_SOURCE_DEFINITION_CONTRACT
from dbt.contracts.graph.compiled import COMPILED_NODE_CONTRACT, CompiledNode
from dbt.exceptions import ValidationException
from dbt.exceptions import raise_duplicate_resource_name
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import tracking
Expand Down Expand Up @@ -401,10 +401,11 @@ def __getattr__(self, name):
type(self).__name__, name)
)

# Return a frozenset of the distinct (database, schema) pairs found on the
# nodes in self.nodes. When resource_types is truthy, only nodes whose
# resource_type is in it are counted; a falsy value (None or an empty
# collection) disables the filter.
#
# NOTE(review): this span is a rendered diff with indentation and +/-
# markers stripped. The two `def` lines appear to be the before/after
# versions of the same signature - confirm against the real file.
def get_used_schemas(self):
def get_used_schemas(self, resource_types=None):
return frozenset({
(node.database, node.schema)
for node in self.nodes.values()
if not resource_types or node.resource_type in resource_types
})

def get_used_databases(self):
Expand Down
Loading