diff --git a/.changes/unreleased/Features-20221129-183239.yaml b/.changes/unreleased/Features-20221129-183239.yaml new file mode 100644 index 00000000000..22a92ea36a7 --- /dev/null +++ b/.changes/unreleased/Features-20221129-183239.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Click CLI Flags work with UserConfig +time: 2022-11-29T18:32:39.068035-05:00 +custom: + Author: michelleark + Issue: "6327" + PR: "6266" diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index 873cdfdfa40..93717875632 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -5,8 +5,13 @@ from importlib import import_module from multiprocessing import get_context from pprint import pformat as pf +from typing import Set from click import Context, get_current_context +from click.core import ParameterSource + +from dbt.config.profile import read_user_config +from dbt.contracts.project import UserConfig if os.name != "nt": # https://bugs.python.org/issue41567 @@ -15,12 +20,12 @@ @dataclass(frozen=True) class Flags: - def __init__(self, ctx: Context = None) -> None: + def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None: if ctx is None: ctx = get_current_context() - def assign_params(ctx): + def assign_params(ctx, params_assigned_from_default): """Recursively adds all click params to flag object""" for param_name, param_value in ctx.params.items(): # N.B. You have to use the base MRO method (object.__setattr__) to set attributes @@ -29,21 +34,40 @@ def assign_params(ctx): if hasattr(self, param_name): raise Exception(f"Duplicate flag names found in click command: {param_name}") object.__setattr__(self, param_name.upper(), param_value) + if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT: + params_assigned_from_default.add(param_name) if ctx.parent: - assign_params(ctx.parent) + assign_params(ctx.parent, params_assigned_from_default) - assign_params(ctx) + params_assigned_from_default = set() # type: Set[str] + assign_params(ctx, params_assigned_from_default) # Get the invoked command flags - if hasattr(ctx, "invoked_subcommand") and ctx.invoked_subcommand is not None: - invoked_subcommand = getattr(import_module("dbt.cli.main"), ctx.invoked_subcommand) + invoked_subcommand_name = ( + ctx.invoked_subcommand if hasattr(ctx, "invoked_subcommand") else None + ) + if invoked_subcommand_name is not None: + invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name) invoked_subcommand.allow_extra_args = True invoked_subcommand.ignore_unknown_options = True invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv) - assign_params(invoked_subcommand_ctx) + assign_params(invoked_subcommand_ctx, params_assigned_from_default) + + if not user_config: + profiles_dir = getattr(self, "PROFILES_DIR", None) + user_config = read_user_config(profiles_dir) if profiles_dir else None + + # Overwrite default assignments with user config if available + if user_config: + for param_assigned_from_default in params_assigned_from_default: + user_config_param_value = getattr(user_config, param_assigned_from_default, None) + if user_config_param_value is not None: + object.__setattr__( + self, param_assigned_from_default.upper(), user_config_param_value + ) # Hard coded flags - object.__setattr__(self, "WHICH", ctx.info_name) + object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name) object.__setattr__(self, "MP_CONTEXT", get_context("spawn")) # Support console DO NOT TRACK initiave @@ -51,7 +75,7 @@ def assign_params(ctx): self, "ANONYMOUS_USAGE_STATS", False - if os.getenv("DO_NOT_TRACK", "").lower() in (1, "t", "true", "y", "yes") + if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes") else True, ) diff --git a/tests/unit/test_cli_flags.py b/tests/unit/test_cli_flags.py new file mode 100644 index 00000000000..d3dedac2390 --- /dev/null +++ b/tests/unit/test_cli_flags.py @@ -0,0 +1,82 @@ +import pytest + +import click +from multiprocessing import get_context +from typing import List + +from dbt.cli.main import cli +from dbt.contracts.project import UserConfig +from dbt.cli.flags import Flags + + +class TestFlags: + def make_dbt_context(self, context_name: str, args: List[str]) -> click.Context: + ctx = cli.make_context(context_name, args) + return ctx + + @pytest.fixture(scope="class") + def run_context(self) -> click.Context: + return self.make_dbt_context("run", ["run"]) + + def test_which(self, run_context): + flags = Flags(run_context) + assert flags.WHICH == "run" + + def test_mp_context(self, run_context): + flags = Flags(run_context) + assert flags.MP_CONTEXT == get_context("spawn") + + @pytest.mark.parametrize('param', cli.params) + def test_cli_group_flags_from_params(self, run_context, param): + flags = Flags(run_context) + assert hasattr(flags, param.name.upper()) + assert getattr(flags, param.name.upper()) == run_context.params[param.name.lower()] + + @pytest.mark.parametrize('do_not_track,expected_anonymous_usage_stats', [ + ("1", False), + ("t", False), + ("true", False), + ("y", False), + ("yes", False), + ("false", True), + ("anything", True), + ("2", True), + ]) + def test_anonymous_usage_state(self, monkeypatch, run_context, do_not_track, expected_anonymous_usage_stats): + monkeypatch.setenv("DO_NOT_TRACK", do_not_track) + + flags = Flags(run_context) + assert flags.ANONYMOUS_USAGE_STATS == expected_anonymous_usage_stats + + def test_empty_user_config_uses_default(self, run_context): + user_config = UserConfig() + + flags = Flags(run_context, user_config) + assert flags.USE_COLORS == run_context.params['use_colors'] + + def test_none_user_config_uses_default(self, run_context): + flags = Flags(run_context, None) + assert flags.USE_COLORS == run_context.params['use_colors'] + + def test_prefer_user_config_to_default(self, run_context): + user_config = UserConfig(use_colors=False) + # ensure default value is not the same as user config + assert run_context.params['use_colors'] is not user_config.use_colors + + flags = Flags(run_context, user_config) + assert flags.USE_COLORS == user_config.use_colors + + def test_prefer_param_value_to_user_config(self): + user_config = UserConfig(use_colors=False) + context = self.make_dbt_context("run", ["--use-colors", "True", "run"]) + + flags = Flags(context, user_config) + assert flags.USE_COLORS + + def test_prefer_env_to_user_config(self, monkeypatch): + user_config = UserConfig(use_colors=False) + monkeypatch.setenv("DBT_USE_COLORS", "True") + context = self.make_dbt_context("run", ["run"]) + + flags = Flags(context, user_config) + assert flags.USE_COLORS