diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index e94cae05f..94d5d4a8c 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -8,6 +8,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context from airflow.utils.operator_helpers import context_to_airflow_vars +from airflow.utils.strings import to_boolean from cosmos.dbt.executable import get_system_dbt from cosmos.log import get_logger @@ -61,7 +62,7 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta): :param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command """ - template_fields: Sequence[str] = ("env", "vars") + template_fields: Sequence[str] = ("env", "select", "exclude", "selector", "vars", "models") global_flags = ( "project_dir", "select", @@ -253,6 +254,26 @@ class DbtBuildMixin: base_cmd = ["build"] ui_color = "#8194E0" + template_fields: Sequence[str] = ("full_refresh",) + + def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None: + self.full_refresh = full_refresh + super().__init__(**kwargs) + + def add_cmd_flags(self) -> list[str]: + flags = [] + + if isinstance(self.full_refresh, str): + # Handle template fields when render_template_as_native_obj=False + full_refresh = to_boolean(self.full_refresh) + else: + full_refresh = self.full_refresh + + if full_refresh is True: + flags.append("--full-refresh") + + return flags + class DbtLSMixin: """ @@ -275,13 +296,20 @@ class DbtSeedMixin: template_fields: Sequence[str] = ("full_refresh",) - def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: + def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None: self.full_refresh = full_refresh super().__init__(**kwargs) def add_cmd_flags(self) -> list[str]: flags = [] - if self.full_refresh is True: + + if isinstance(self.full_refresh, str): + # Handle template fields when render_template_as_native_obj=False + full_refresh = to_boolean(self.full_refresh) + else: + full_refresh = self.full_refresh + + if full_refresh is True: flags.append("--full-refresh") return flags @@ -307,13 +335,20 @@ class DbtRunMixin: template_fields: Sequence[str] = ("full_refresh",) - def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: + def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None: self.full_refresh = full_refresh super().__init__(**kwargs) def add_cmd_flags(self) -> list[str]: flags = [] - if self.full_refresh is True: + + if isinstance(self.full_refresh, str): + # Handle template fields when render_template_as_native_obj=False + full_refresh = to_boolean(self.full_refresh) + else: + full_refresh = self.full_refresh + + if full_refresh is True: flags.append("--full-refresh") return flags diff --git a/cosmos/operators/docker.py b/cosmos/operators/docker.py index 5be03fad7..532de380e 100644 --- a/cosmos/operators/docker.py +++ b/cosmos/operators/docker.py @@ -69,6 +69,8 @@ class DbtBuildDockerOperator(DbtBuildMixin, DbtDockerBaseOperator): Executes a dbt core build command. """ + template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator] + class DbtLSDockerOperator(DbtLSMixin, DbtDockerBaseOperator): """ diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py index e314b7c43..14bcbcb84 100644 --- a/cosmos/operators/kubernetes.py +++ b/cosmos/operators/kubernetes.py @@ -102,6 +102,8 @@ class DbtBuildKubernetesOperator(DbtBuildMixin, DbtKubernetesBaseOperator): Executes a dbt core build command. """ + template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator] + class DbtLSKubernetesOperator(DbtLSMixin, DbtKubernetesBaseOperator): """ diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 7e71fcba6..e6a09748f 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -468,6 +468,8 @@ class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator): Executes a dbt core build command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator] + class DbtLSLocalOperator(DbtLSMixin, DbtLocalBaseOperator): """ diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index 2261f764c..b17772b88 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -104,7 +104,7 @@ def execute(self, context: Context) -> None: logger.info(output) -class DbtBuildVirtualenvOperator(DbtVirtualenvBaseOperator, DbtBuildLocalOperator): +class DbtBuildVirtualenvOperator(DbtVirtualenvBaseOperator, DbtBuildLocalOperator): # type: ignore[misc] """ Executes a dbt core build command within a Python Virtual Environment, that is created before running the dbt command and deleted just after. diff --git a/docs/configuration/operator-args.rst b/docs/configuration/operator-args.rst index 5ddbe6565..4e6a40b7f 100644 --- a/docs/configuration/operator-args.rst +++ b/docs/configuration/operator-args.rst @@ -54,6 +54,7 @@ dbt-related - ``quiet``: run ``dbt`` in silent mode, only displaying its error logs. - ``vars``: (Deprecated since Cosmos 1.3 use ``ProjectConfig.dbt_vars`` instead) Supply variables to the project. This argument overrides variables defined in the ``dbt_project.yml``. - ``warn_error``: convert ``dbt`` warnings into errors. +- ``full_refresh``: If True, then full refresh the node. This only applies to model and seed nodes. Airflow-related ............... @@ -88,3 +89,33 @@ Sample usage "skip_exit_code": 1, } ) + + +Template fields +--------------- + +Some of the operator args are `template fields `_ for your convenience. + +These template fields can be useful for hooking into Airflow `Params `_, or for more advanced customization with `XComs `_. + +The following operator args support templating, and are accessible both through the ``DbtDag`` and ``DbtTaskGroup`` constructors in addition to being accessible standalone: + +- ``env`` +- ``vars`` +- ``full_refresh`` (for the ``build``, ``seed``, and ``run`` operators since Cosmos 1.4.) + +.. note:: + Using Jinja templating for ``env`` and ``vars`` may cause problems when using ``LoadMode.DBT_LS`` to render your DAG. + +The following template fields are only selectable when using the operators in a standalone context (starting in Cosmos 1.4): + +- ``select`` +- ``exclude`` +- ``selector`` +- ``models`` + +Since Airflow resolves template fields during Airflow DAG execution and not DAG parsing, the args above cannot be templated via ``DbtDag`` and ``DbtTaskGroup`` because both need to select dbt nodes during DAG parsing. + +Additionally, the SQL for compiled dbt models is stored in the template fields, which is viewable in the Airflow UI for each task run. +This is provided for telemetry on task execution, and is not an operator arg. +For more information about this, see the `Compiled SQL `_ docs. diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py index edaf8a845..5761d66aa 100644 --- a/tests/operators/test_base.py +++ b/tests/operators/test_base.py @@ -51,8 +51,10 @@ def test_dbt_mixin_base_cmd(dbt_command, dbt_operator_class): assert [dbt_command] == dbt_operator_class.base_cmd -@pytest.mark.parametrize("dbt_operator_class", [DbtSeedMixin, DbtRunMixin]) -@pytest.mark.parametrize("full_refresh, expected_flags", [(True, ["--full-refresh"]), (False, [])]) +@pytest.mark.parametrize("dbt_operator_class", [DbtSeedMixin, DbtRunMixin, DbtBuildMixin]) +@pytest.mark.parametrize( + "full_refresh, expected_flags", [("True", ["--full-refresh"]), (True, ["--full-refresh"]), (False, [])] +) def test_dbt_mixin_add_cmd_flags_full_refresh(full_refresh, expected_flags, dbt_operator_class): dbt_mixin = dbt_operator_class(full_refresh=full_refresh) flags = dbt_mixin.add_cmd_flags() diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 05cca4bcc..c054cf4d1 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -567,6 +567,7 @@ def test_store_compiled_sql() -> None: "operator_class,kwargs,expected_call_kwargs", [ (DbtSeedLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}), + (DbtBuildLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}), (DbtRunLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}), ( DbtTestLocalOperator, @@ -650,8 +651,18 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo @pytest.mark.parametrize( "operator_class,expected_template", [ - (DbtSeedLocalOperator, ("env", "vars", "compiled_sql", "full_refresh")), - (DbtRunLocalOperator, ("env", "vars", "compiled_sql", "full_refresh")), + ( + DbtSeedLocalOperator, + ("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"), + ), + ( + DbtRunLocalOperator, + ("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"), + ), + ( + DbtBuildLocalOperator, + ("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"), + ), ], ) def test_dbt_base_operator_template_fields(operator_class, expected_template):