Skip to content

Commit

Permalink
Add more template fields to DbtBaseOperator (#786)
Browse files Browse the repository at this point in the history
This partially addresses #754 via allowing for built-in templating
support for the `DbtBaseOperator`.

I also noticed `--full-refresh` was not documented so I added that in.

## Still missing

Manual run pattern is not documented; the fact that these fields are
templated is not documented. I don't really know where in the docs to
put this. The docs are very API-focused more than narrative-based or
suggestive, and Cosmos's maintainers prefer this style of documentation
so it's hard to find a spot for this. It's possible that that's fine and
we just keep this as a feature for more advanced users who dig into
source code to discover for themselves? 🤷

Co-authored-by: Tatiana Al-Chueyr <[email protected]>
  • Loading branch information
dwreeves and tatiana authored Feb 29, 2024
1 parent 1bf78b4 commit 95ad1cb
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 10 deletions.
45 changes: 40 additions & 5 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.strings import to_boolean

from cosmos.dbt.executable import get_system_dbt
from cosmos.log import get_logger
Expand Down Expand Up @@ -61,7 +62,7 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
:param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command
"""

template_fields: Sequence[str] = ("env", "vars")
template_fields: Sequence[str] = ("env", "select", "exclude", "selector", "vars", "models")
global_flags = (
"project_dir",
"select",
Expand Down Expand Up @@ -253,6 +254,26 @@ class DbtBuildMixin:
base_cmd = ["build"]
ui_color = "#8194E0"

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
    """
    Record the ``full_refresh`` setting and hand every other keyword argument to the next ``__init__`` in the MRO.

    :param full_refresh: Whether ``--full-refresh`` should be added to the dbt command. Accepted either as a bool
        or as a string; ``add_cmd_flags`` converts the string form before deciding.
    :param kwargs: Remaining keyword arguments, forwarded unchanged via ``super().__init__``.
    """
    # NOTE(review): the attribute is set before delegating to super(); presumably the base
    # initialiser expects it to exist already -- confirm before reordering these two lines.
    self.full_refresh = full_refresh
    super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
    """
    Build the list of extra dbt command-line flags derived from ``full_refresh``.

    :return: ``["--full-refresh"]`` when ``full_refresh`` resolves to exactly ``True``, otherwise an empty list.
    """
    refresh = self.full_refresh

    # With render_template_as_native_obj=False a templated value arrives as a string,
    # so it has to be converted to a real boolean before the identity check below.
    if isinstance(refresh, str):
        refresh = to_boolean(refresh)

    return ["--full-refresh"] if refresh is True else []


class DbtLSMixin:
"""
Expand All @@ -275,13 +296,20 @@ class DbtSeedMixin:

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:

if isinstance(self.full_refresh, str):
# Handle template fields when render_template_as_native_obj=False
full_refresh = to_boolean(self.full_refresh)
else:
full_refresh = self.full_refresh

if full_refresh is True:
flags.append("--full-refresh")

return flags
Expand All @@ -307,13 +335,20 @@ class DbtRunMixin:

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:

if isinstance(self.full_refresh, str):
# Handle template fields when render_template_as_native_obj=False
full_refresh = to_boolean(self.full_refresh)
else:
full_refresh = self.full_refresh

if full_refresh is True:
flags.append("--full-refresh")

return flags
Expand Down
2 changes: 2 additions & 0 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class DbtBuildDockerOperator(DbtBuildMixin, DbtDockerBaseOperator):
Executes a dbt core build command.
"""

template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]


class DbtLSDockerOperator(DbtLSMixin, DbtDockerBaseOperator):
"""
Expand Down
2 changes: 2 additions & 0 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class DbtBuildKubernetesOperator(DbtBuildMixin, DbtKubernetesBaseOperator):
Executes a dbt core build command.
"""

template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]


class DbtLSKubernetesOperator(DbtLSMixin, DbtKubernetesBaseOperator):
"""
Expand Down
2 changes: 2 additions & 0 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator):
Executes a dbt core build command.
"""

template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]


class DbtLSLocalOperator(DbtLSMixin, DbtLocalBaseOperator):
"""
Expand Down
2 changes: 1 addition & 1 deletion cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def execute(self, context: Context) -> None:
logger.info(output)


class DbtBuildVirtualenvOperator(DbtVirtualenvBaseOperator, DbtBuildLocalOperator):
class DbtBuildVirtualenvOperator(DbtVirtualenvBaseOperator, DbtBuildLocalOperator): # type: ignore[misc]
"""
Executes a dbt core build command within a Python Virtual Environment, that is created before running the dbt command
and deleted just after.
Expand Down
31 changes: 31 additions & 0 deletions docs/configuration/operator-args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dbt-related
- ``quiet``: run ``dbt`` in silent mode, only displaying its error logs.
- ``vars``: (Deprecated since Cosmos 1.3 use ``ProjectConfig.dbt_vars`` instead) Supply variables to the project. This argument overrides variables defined in the ``dbt_project.yml``.
- ``warn_error``: convert ``dbt`` warnings into errors.
- ``full_refresh``: if ``True``, run the node with the dbt ``--full-refresh`` flag. This only applies to model and seed nodes.

Airflow-related
...............
Expand Down Expand Up @@ -88,3 +89,33 @@ Sample usage
"skip_exit_code": 1,
}
)
Template fields
---------------

Some of the operator args are `template fields <https://airflow.apache.org/docs/apache-airflow/stable/howto/custom-operator.html#templating>`_ for your convenience.

These template fields can be useful for hooking into Airflow `Params <https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/params.html>`_, or for more advanced customization with `XComs <https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/xcoms.html>`_.

The following operator args support templating. They can be set through the ``DbtDag`` and ``DbtTaskGroup`` constructors as well as on standalone operators:

- ``env``
- ``vars``
- ``full_refresh`` (for the ``build``, ``seed``, and ``run`` operators, since Cosmos 1.4)

.. note::
Using Jinja templating for ``env`` and ``vars`` may cause problems when using ``LoadMode.DBT_LS`` to render your DAG.

The following operator args can only be templated when using the operators in a standalone context (starting in Cosmos 1.4):

- ``select``
- ``exclude``
- ``selector``
- ``models``

Since Airflow resolves template fields during Airflow DAG execution and not DAG parsing, the args above cannot be templated via ``DbtDag`` and ``DbtTaskGroup`` because both need to select dbt nodes during DAG parsing.

Additionally, the SQL for compiled dbt models is stored in the template fields, which is viewable in the Airflow UI for each task run.
This is provided for telemetry on task execution, and is not an operator arg.
For more information about this, see the `Compiled SQL <compiled-sql.html>`_ docs.
6 changes: 4 additions & 2 deletions tests/operators/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def test_dbt_mixin_base_cmd(dbt_command, dbt_operator_class):
assert [dbt_command] == dbt_operator_class.base_cmd


@pytest.mark.parametrize("dbt_operator_class", [DbtSeedMixin, DbtRunMixin])
@pytest.mark.parametrize("full_refresh, expected_flags", [(True, ["--full-refresh"]), (False, [])])
@pytest.mark.parametrize("dbt_operator_class", [DbtSeedMixin, DbtRunMixin, DbtBuildMixin])
@pytest.mark.parametrize(
"full_refresh, expected_flags", [("True", ["--full-refresh"]), (True, ["--full-refresh"]), (False, [])]
)
def test_dbt_mixin_add_cmd_flags_full_refresh(full_refresh, expected_flags, dbt_operator_class):
dbt_mixin = dbt_operator_class(full_refresh=full_refresh)
flags = dbt_mixin.add_cmd_flags()
Expand Down
15 changes: 13 additions & 2 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def test_store_compiled_sql() -> None:
"operator_class,kwargs,expected_call_kwargs",
[
(DbtSeedLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtBuildLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtRunLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(
DbtTestLocalOperator,
Expand Down Expand Up @@ -650,8 +651,18 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo
@pytest.mark.parametrize(
"operator_class,expected_template",
[
(DbtSeedLocalOperator, ("env", "vars", "compiled_sql", "full_refresh")),
(DbtRunLocalOperator, ("env", "vars", "compiled_sql", "full_refresh")),
(
DbtSeedLocalOperator,
("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"),
),
(
DbtRunLocalOperator,
("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"),
),
(
DbtBuildLocalOperator,
("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"),
),
],
)
def test_dbt_base_operator_template_fields(operator_class, expected_template):
Expand Down

0 comments on commit 95ad1cb

Please sign in to comment.