Skip to content

Commit

Permalink
Fix docs generation for cross-db sources in REDSHIFT RA3 node (#3408)
Browse files Browse the repository at this point in the history
* Fix docs generating for cross-db sources

* Code reorganization

* Code adjustments according to flake8

* Error message adjusted to be more precise

* CHANGELOG update
  • Loading branch information
kostek-pl authored Jun 9, 2021
1 parent abe8e83 commit eb4ad44
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 7 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
## dbt 0.21.0

### Features

### Fixes
- Fix docs generation for cross-db sources in REDSHIFT RA3 node ([#3236](https:/fishtown-analytics/dbt/issues/3236), [#3408](https:/fishtown-analytics/dbt/pull/3408))

### Under the hood

Contributors:
- [@kostek-pl](https:/kostek-pl) ([#3236](https:/fishtown-analytics/dbt/pull/3408))

## dbt 0.20.0 (Release TBD)

## dbt 0.20.0rc1 (June 04, 2021)
Expand Down
11 changes: 6 additions & 5 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,14 @@ def search(
for schema in schemas:
yield information_schema_name, schema

def flatten(self):
def flatten(self, allow_multiple_databases: bool = False):
new = self.__class__()

# make sure we don't have duplicates
seen = {r.database.lower() for r in self if r.database}
if len(seen) > 1:
dbt.exceptions.raise_compiler_error(str(seen))
# make sure we don't have multiple databases if allow_multiple_databases is set to False
if not allow_multiple_databases:
seen = {r.database.lower() for r in self if r.database}
if len(seen) > 1:
dbt.exceptions.raise_compiler_error(str(seen))

for information_schema_name, schema in self.search():
path = {
Expand Down
2 changes: 1 addition & 1 deletion plugins/postgres/dbt/adapters/postgres/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _link_cached_database_relations(self, schemas: Set[str]):
self.cache.add_link(referenced, dependent)

def _get_catalog_schemas(self, manifest):
# postgres/redshift only allow one database (the main one)
# postgres only allow one database (the main one)
schemas = super()._get_catalog_schemas(manifest)
try:
return schemas.flatten()
Expand Down
1 change: 1 addition & 0 deletions plugins/redshift/dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class RedshiftCredentials(PostgresCredentials):
keepalives_idle: int = 240
autocreate: bool = False
db_groups: List[str] = field(default_factory=list)
ra3_node: Optional[bool] = False

@property
def type(self):
Expand Down
31 changes: 30 additions & 1 deletion plugins/redshift/dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from dataclasses import dataclass
from typing import Optional
from dbt.adapters.base.impl import AdapterConfig
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.base.meta import available
from dbt.adapters.postgres import PostgresAdapter
from dbt.adapters.redshift import RedshiftConnectionManager
from dbt.adapters.redshift import RedshiftColumn
from dbt.adapters.redshift import RedshiftRelation
from dbt.logger import GLOBAL_LOGGER as logger # noqa
import dbt.exceptions


@dataclass
Expand All @@ -16,7 +19,7 @@ class RedshiftConfig(AdapterConfig):
bind: Optional[bool] = None


class RedshiftAdapter(PostgresAdapter):
class RedshiftAdapter(PostgresAdapter, SQLAdapter):
Relation = RedshiftRelation
ConnectionManager = RedshiftConnectionManager
Column = RedshiftColumn
Expand Down Expand Up @@ -57,3 +60,29 @@ def convert_text_type(cls, agate_table, col_idx):
@classmethod
def convert_time_type(cls, agate_table, col_idx):
return "varchar(24)"

@available
def verify_database(self, database):
if database.startswith('"'):
database = database.strip('"')
expected = self.config.credentials.database
ra3_node = self.config.credentials.ra3_node

if database.lower() != expected.lower() and not ra3_node:
raise dbt.exceptions.NotImplementedException(
'Cross-db references allowed only in RA3.* node. ({} vs {})'
.format(database, expected)
)
# return an empty string on success so macros can call this
return ''

def _get_catalog_schemas(self, manifest):
# redshift(besides ra3) only allow one database (the main one)
schemas = super(SQLAdapter, self)._get_catalog_schemas(manifest)
try:
return schemas.flatten(allow_multiple_databases=self.config.credentials.ra3_node)
except dbt.exceptions.RuntimeException as exc:
dbt.exceptions.raise_compiler_error(
'Cross-db references allowed only in {} RA3.* node. Got {}'
.format(self.type(), exc.msg)
)

0 comments on commit eb4ad44

Please sign in to comment.