Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple macro generator from adapters #9149

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
table_from_rows,
Integer,
)
from dbt.clients.jinja import MacroGenerator
from dbt.common.clients.jinja import CallableMacroGenerator
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.common.events.functions import fire_event, warn_or_error
Expand Down Expand Up @@ -1115,7 +1115,7 @@ def execute_macro(
)
macro_context.update(context_override)

macro_function = MacroGenerator(macro, macro_context)
macro_function = CallableMacroGenerator(macro, macro_context)

with self.connections.exception_handler(f"macro {macro_name}"):
result = macro_function(**kwargs)
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def get_incremental_strategy_macro(self, model_context, strategy: str):

strategy = strategy.replace("+", "_")
macro_name = f"get_incremental_{strategy}_sql"
# The model_context should have MacroGenerator callable objects for all macros
# The model_context should have callable objects for all macros
if macro_name not in model_context:
raise DbtRuntimeError(
'dbt could not find an incremental strategy macro with the name "{}" in {}'.format(
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from threading import local
from typing import Optional, Callable, Dict, Any

from dbt.clients.jinja import QueryStringGenerator
from dbt.adapters.clients.jinja import QueryStringGenerator

from dbt.context.manifest import generate_query_header_context
from dbt.adapters.contracts.connection import AdapterRequiredConfig, QueryComment
Expand Down
Empty file.
23 changes: 23 additions & 0 deletions core/dbt/adapters/clients/jinja.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Dict, Any
from dbt.common.clients.jinja import BaseMacroGenerator, get_environment


class QueryStringGenerator(BaseMacroGenerator):
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
super().__init__(context)
self.template_str: str = template_str
env = get_environment()
self.template = env.from_string(
self.template_str,
globals=self.context,
)

def get_name(self) -> str:
return "query_comment_macro"

def get_template(self):
"""Don't use the template cache, we don't have a node"""
return self.template

def __call__(self, connection_name: str, node) -> str:
return str(self.call_macro(connection_name, node))
8 changes: 6 additions & 2 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
from dbt.cli.types import Command as CliCommand
from dbt.common import ui
from dbt.common.events import functions
from dbt.common.exceptions import DbtInternalError
from dbt.common.clients import jinja
from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
from dbt.common.exceptions import DbtInternalError
from dbt.deprecations import renamed_env_var
from dbt.common.helper_types import WarnErrorOptions

Expand Down Expand Up @@ -305,11 +306,14 @@ def set_common_global_flags(self):
ui.USE_COLOR = getattr(self, "USE_COLORS")

# Set globals for common.events.functions
#
functions.WARN_ERROR = getattr(self, "WARN_ERROR", False)
if getattr(self, "WARN_ERROR_OPTIONS", None) is not None:
functions.WARN_ERROR_OPTIONS = getattr(self, "WARN_ERROR_OPTIONS")

# Set globals for common.jinja
if getattr(self, "MACRO_DEBUGGING", None) is not None:
jinja.MACRO_DEBUGGING = getattr(self, "MACRO_DEBUGGING")


CommandParams = List[str]

Expand Down
Loading
Loading