Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

[AP-591] Use SHOW SCHEMAS|TABLES|COLUMNS instead of INFORMATION_SCHEMA #14

Merged
merged 6 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 132 additions & 29 deletions tap_snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
import collections
import copy
import itertools
import re
import sys
import logging

import singer
import singer.metrics as metrics
import singer.schema
import snowflake.connector
from singer import metadata
from singer import utils
from singer.catalog import Catalog, CatalogEntry
Expand All @@ -21,6 +24,11 @@

LOGGER = singer.get_logger('tap_snowflake')

# Max number of rows that a SHOW SCHEMAS|TABLES|COLUMNS can return.
# If more than this number of rows returned then tap-snowflake will raise TooManyRecordsException
SHOW_COMMAND_MAX_ROWS = 9999


# Tone down snowflake connector logs noise
logging.getLogger('snowflake.connector').setLevel(logging.WARNING)

Expand Down Expand Up @@ -110,46 +118,140 @@ def create_column_metadata(cols):
return metadata.to_list(mdata)


def get_databases(snowflake_conn):
    """Return the names of all databases visible to the connection.

    Runs ``SHOW DATABASES`` (capped at SHOW_COMMAND_MAX_ROWS rows) and
    extracts the ``name`` column from each returned row.
    """
    rows = snowflake_conn.query('SHOW DATABASES', max_records=SHOW_COMMAND_MAX_ROWS)

    return [row['name'] for row in rows]


def get_schemas(snowflake_conn, database):
    """Get the schema names of a database.

    Runs ``SHOW SCHEMAS IN DATABASE`` and returns the schema names as a list
    of strings. If the database does not exist, an empty list is returned
    instead of raising.
    """
    schemas = []
    try:
        rows = snowflake_conn.query(f'SHOW SCHEMAS IN DATABASE {database}',
                                    max_records=SHOW_COMMAND_MAX_ROWS)

        # Keep only the name of each schema
        schemas = [row['name'] for row in rows]

    # SHOW SCHEMAS raises a ProgrammingError when the database does not exist.
    # Swallow exactly that case - snowflake error code 02000 followed by a
    # "does not exist" message on the next line - and re-raise anything else.
    except snowflake.connector.errors.ProgrammingError as exc:
        # Raw string avoids the anomalous-backslash warning the original
        # needed a pylint suppression for; match on str(exc) rather than
        # sys.exc_info() since they are the same object here.
        if not re.match(r'.*\(02000\):.*\n.*does not exist.*', str(exc)):
            # Bare raise preserves the original traceback
            raise

    return schemas


def get_table_columns(snowflake_conn, database, table_schemas=None, table_name=None):
    """Get column definitions for every table in specific schema(s).

    Uses SHOW commands instead of INFORMATION_SCHEMA views because the
    information_schema views are slow and can cause an unexpected exception of:
    "Information schema query returned too much data. Please repeat query with
    more selective predicates."

    :param snowflake_conn: open SnowflakeConnection wrapper
    :param database: database to inspect
    :param table_schemas: list of schema names to scan
    :param table_name: currently unused filter placeholder  # NOTE(review): never referenced below - confirm intent
    :return: list of dict rows with table/column metadata
    """
    table_columns = []
    if table_schemas or table_name:
        # `table_schemas or []` guards the case where only table_name was
        # given (the original code raised TypeError iterating over None)
        for schema in table_schemas or []:
            queries = []

            LOGGER.info('Getting schema information for %s.%s...', database, schema)

            # Get column data types by SHOW commands
            show_tables = f'SHOW TABLES IN SCHEMA {database}.{schema}'
            # Bug fix: this originally ran SHOW TABLES again, so the
            # show_views result scan below could never contain views and the
            # BASE TABLE / VIEW classification was broken
            show_views = f'SHOW VIEWS IN SCHEMA {database}.{schema}'
            show_columns = f'SHOW COLUMNS IN SCHEMA {database}.{schema}'

            # Convert output of the SHOW commands (read back via
            # RESULT_SCAN/LAST_QUERY_ID relative offsets) to tables and use
            # SQL joins to derive every required piece of information
            select = f"""
                WITH
                  show_tables  AS (SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-3)))),
                  show_views   AS (SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-2)))),
                  show_columns AS (SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-1))))
                SELECT show_columns."database_name" AS table_catalog
                      ,show_columns."schema_name"   AS table_schema
                      ,show_columns."table_name"    AS table_name
                      ,CASE
                         WHEN show_tables."name" IS NOT NULL THEN 'BASE TABLE'
                         ELSE 'VIEW'
                       END table_type
                      ,show_tables."rows"           AS row_count
                      ,show_columns."column_name"   AS column_name
                      -- ----------------------------------------------------------------------------------------
                      -- Character and numeric columns display their generic data type rather than their defined
                      -- data type (i.e. TEXT for all character types, FIXED for all fixed-point numeric types,
                      -- and REAL for all floating-point numeric types).
                      --
                      -- Further info at https://docs.snowflake.net/manuals/sql-reference/sql/show-columns.html
                      -- ----------------------------------------------------------------------------------------
                      ,CASE PARSE_JSON(show_columns."data_type"):type::varchar
                         WHEN 'FIXED' THEN 'NUMBER'
                         WHEN 'REAL'  THEN 'FLOAT'
                         ELSE PARSE_JSON(show_columns."data_type"):type::varchar
                       END data_type
                      ,PARSE_JSON(show_columns."data_type"):length::number    AS character_maximum_length
                      ,PARSE_JSON(show_columns."data_type"):precision::number AS numeric_precision
                      ,PARSE_JSON(show_columns."data_type"):scale::number     AS numeric_scale
                  FROM show_columns
                  LEFT JOIN show_tables
                    ON show_tables."database_name" = show_columns."database_name"
                   AND show_tables."schema_name"   = show_columns."schema_name"
                   AND show_tables."name"          = show_columns."table_name"
                  LEFT JOIN show_views
                    ON show_views."database_name"  = show_columns."database_name"
                   AND show_views."schema_name"    = show_columns."schema_name"
                   AND show_views."name"           = show_columns."table_name"
            """
            queries.extend([show_tables, show_views, show_columns, select])

            # Run everything in one transaction so the RESULT_SCAN offsets
            # reference the SHOW statements just issued
            try:
                columns = snowflake_conn.query(queries, max_records=SHOW_COMMAND_MAX_ROWS)
                table_columns.extend(columns)

            # SHOW COLUMNS throws a ProgrammingError when the schema does not
            # exist; swallow exactly that case (snowflake error code 02000
            # with a "does not exist" message) and re-raise anything else
            except snowflake.connector.errors.ProgrammingError as exc:
                if not re.match(r'.*\(02000\):.*\n.*does not exist.*', str(exc)):
                    raise

    return table_columns


def discover_catalog(snowflake_conn, config):
"""Returns a Catalog describing the structure of the database."""
filter_dbs_config = config.get('filter_dbs')
filter_schemas_config = config.get('filter_schemas')
databases = []
schemas = []

# Get databases
sql_columns = []
if filter_dbs_config:
filter_dbs_clause = ','.join(f"LOWER('{db}')" for db in filter_dbs_config.split(','))

table_db_clause = f'LOWER(t.table_catalog) IN ({filter_dbs_clause})'
databases = filter_dbs_config.split(',')
else:
table_db_clause = '1 = 1'
databases = get_databases(snowflake_conn)
for database in databases:

if filter_schemas_config:
filter_schemas_clause = ','.join([f"LOWER('{schema}')" for schema in filter_schemas_config.split(',')])
# Get schemas
if filter_schemas_config:
schemas = filter_schemas_config.split(',')
else:
schemas = get_schemas(snowflake_conn, database)

table_schema_clause = f'LOWER(t.table_schema) IN ({filter_schemas_clause})'
else:
table_schema_clause = "LOWER(t.table_schema) NOT IN ('information_schema')"
table_columns = get_table_columns(snowflake_conn, database, schemas)
sql_columns.extend(table_columns)

table_info = {}
sql_columns = snowflake_conn.query("""
SELECT t.table_catalog,
t.table_schema,
t.table_name,
t.table_type,
t.row_count,
c.column_name,
c.data_type,
c.character_maximum_length,
c.numeric_precision,
c.numeric_scale
FROM information_schema.tables t,
information_schema.columns c
WHERE t.table_catalog = c.table_catalog
AND t.table_schema = c.table_schema
AND t.table_name = c.table_name
AND {}
AND {}
""".format(table_db_clause, table_schema_clause))

columns = []
for sql_col in sql_columns:
catalog = sql_col['TABLE_CATALOG']
Expand Down Expand Up @@ -217,6 +319,7 @@ def discover_catalog(snowflake_conn, config):
def do_discover(snowflake_conn, config):
    """Run catalog discovery and dump the result to stdout."""
    catalog = discover_catalog(snowflake_conn, config)
    catalog.dump()


# pylint: disable=fixme
# TODO: Maybe put in a singer-db-utils library.
def desired_columns(selected, table_schema):
Expand Down
35 changes: 26 additions & 9 deletions tap_snowflake/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
LOGGER = singer.get_logger('tap_snowflake')


class TooManyRecordsException(Exception):
    """Raised when a query returns more rows than the allowed max_records limit."""


def retry_pattern():
"""Retry pattern decorator used when connecting to snowflake
"""
Expand Down Expand Up @@ -79,17 +83,30 @@ def connect_with_backoff(self):
return self.open_connection()


def query(self, query, params=None, max_records=0):
    """Run one or more queries in snowflake.

    `query` may be a single SQL string or a list of SQL strings. A list is
    executed inside one explicit transaction (prefixed with START
    TRANSACTION) so that statements such as SHOW ... followed by
    RESULT_SCAN(LAST_QUERY_ID(...)) can see each other's results.

    :param query: SQL string or list of SQL strings
    :param params: bind parameters passed to every execute call
    :param max_records: raise TooManyRecordsException when any statement
                        returns more rows than this; 0 disables the check
    :return: rows (as dicts) of the last statement that produced rows,
             otherwise an empty list

    NOTE(review): this block reconstructs the merged version of the diff -
    the scraped source interleaved the old signature, the old single
    cur.execute call and the old early return with the new implementation.
    """
    LOGGER.info('SNOWFLAKE - Running query: %s', query)
    result = []
    with self.connect_with_backoff() as connection:
        with connection.cursor(snowflake.connector.DictCursor) as cur:
            # Run every query in one transaction if query is a list of SQL
            if isinstance(query, list):
                queries = ['START TRANSACTION'] + list(query)
            else:
                queries = [query]

            for sql in queries:
                LOGGER.debug('SNOWFLAKE - Running query: %s', sql)
                cur.execute(sql, params)

                # Raise exception if returned rows greater than max allowed records
                if 0 < max_records < cur.rowcount:
                    raise TooManyRecordsException(
                        f'Query returned too many records. This query can return max {max_records} records')

                # Remember the most recent non-empty result set; do NOT
                # return early, otherwise the transaction's later
                # statements would never run
                if cur.rowcount > 0:
                    result = cur.fetchall()

    return result
45 changes: 23 additions & 22 deletions tests/integration/test_tap_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,25 @@

from singer.schema import Schema


try:
import tests.utils as test_utils
except ImportError:
import utils as test_utils

LOGGER = singer.get_logger('tap_snowflake_tests')

SCHEMA_NAME='tap_snowflake_test'
SCHEMA_NAME = 'tap_snowflake_test'

SINGER_MESSAGES = []


def accumulate_singer_messages(message):
    """Capture a singer message into the module-level SINGER_MESSAGES list."""
    SINGER_MESSAGES.append(message)


singer.write_message = accumulate_singer_messages


class TestTypeMapping(unittest.TestCase):

@classmethod
Expand Down Expand Up @@ -181,10 +183,10 @@ def test_row_to_singer_record(self):

# Convert the exported data to singer JSON
record_message = common.row_to_singer_record(catalog_entry=catalog_entry,
version=1,
row=row,
columns=columns,
time_extracted=singer.utils.now())
version=1,
row=row,
columns=columns,
time_extracted=singer.utils.now())

# Convert to formatted JSON
formatted_record = singer.messages.format_message(record_message)
Expand All @@ -193,21 +195,21 @@ def test_row_to_singer_record(self):
self.assertEquals(json.loads(formatted_record)['type'], 'RECORD')
self.assertEquals(json.loads(formatted_record)['stream'], 'TEST_TYPE_MAPPING')
self.assertEquals(json.loads(formatted_record)['record'],
{
'C_PK': 1,
'C_DECIMAL': 12345,
'C_DECIMAL_2': 123456789.12,
'C_SMALLINT': 123,
'C_INT': 12345,
'C_BIGINT': 1234567890,
'C_FLOAT': 123.123,
'C_DOUBLE': 123.123,
'C_DATE': '2019-08-01T00:00:00+00:00',
'C_DATETIME': '2019-08-01T17:23:59+00:00',
'C_TIME': '17:23:59',
'C_BINARY': '62696E617279',
'C_VARBINARY': '76617262696E617279'
})
{
'C_PK': 1,
'C_DECIMAL': 12345,
'C_DECIMAL_2': 123456789.12,
'C_SMALLINT': 123,
'C_INT': 12345,
'C_BIGINT': 1234567890,
'C_FLOAT': 123.123,
'C_DOUBLE': 123.123,
'C_DATE': '2019-08-01T00:00:00+00:00',
'C_DATETIME': '2019-08-01T17:23:59+00:00',
'C_TIME': '17:23:59',
'C_BINARY': '62696E617279',
'C_VARBINARY': '76617262696E617279'
})


class TestSelectsAppropriateColumns(unittest.TestCase):
Expand All @@ -225,4 +227,3 @@ def runTest(self):
self.assertEqual(got_cols,
set(['a', 'c']),
'Keep automatic as well as selected, available columns.')