Skip to content

Commit

Permalink
Set Flags from UserConfig (#6266)
Browse files Browse the repository at this point in the history
flags with user config, flags.WHICH from invoked_subcommand if available
  • Loading branch information
MichelleArk authored and iknox-fa committed Dec 4, 2022
1 parent c2d2255 commit ae77041
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 9 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20221129-183239.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Click CLI Flags work with UserConfig
time: 2022-11-29T18:32:39.068035-05:00
custom:
Author: michelleark
Issue: "6327"
PR: "6266"
42 changes: 33 additions & 9 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Set

from click import Context, get_current_context
from click.core import ParameterSource

from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig

if os.name != "nt":
# https://bugs.python.org/issue41567
Expand All @@ -15,12 +20,12 @@

@dataclass(frozen=True)
class Flags:
def __init__(self, ctx: Context = None) -> None:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:

if ctx is None:
ctx = get_current_context()

def assign_params(ctx):
def assign_params(ctx, params_assigned_from_default):
"""Recursively adds all click params to flag object"""
for param_name, param_value in ctx.params.items():
# N.B. You have to use the base MRO method (object.__setattr__) to set attributes
Expand All @@ -29,29 +34,48 @@ def assign_params(ctx):
if hasattr(self, param_name):
raise Exception(f"Duplicate flag names found in click command: {param_name}")
object.__setattr__(self, param_name.upper(), param_value)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)
if ctx.parent:
assign_params(ctx.parent)
assign_params(ctx.parent, params_assigned_from_default)

assign_params(ctx)
params_assigned_from_default = set() # type: Set[str]
assign_params(ctx, params_assigned_from_default)

# Get the invoked command flags
if hasattr(ctx, "invoked_subcommand") and ctx.invoked_subcommand is not None:
invoked_subcommand = getattr(import_module("dbt.cli.main"), ctx.invoked_subcommand)
invoked_subcommand_name = (
ctx.invoked_subcommand if hasattr(ctx, "invoked_subcommand") else None
)
if invoked_subcommand_name is not None:
invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name)
invoked_subcommand.allow_extra_args = True
invoked_subcommand.ignore_unknown_options = True
invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv)
assign_params(invoked_subcommand_ctx)
assign_params(invoked_subcommand_ctx, params_assigned_from_default)

if not user_config:
profiles_dir = getattr(self, "PROFILES_DIR", None)
user_config = read_user_config(profiles_dir) if profiles_dir else None

# Overwrite default assignments with user config if available
if user_config:
for param_assigned_from_default in params_assigned_from_default:
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self, param_assigned_from_default.upper(), user_config_param_value
)

# Hard coded flags
object.__setattr__(self, "WHICH", ctx.info_name)
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))

# Support console DO NOT TRACK initiave
object.__setattr__(
self,
"ANONYMOUS_USAGE_STATS",
False
if os.getenv("DO_NOT_TRACK", "").lower() in (1, "t", "true", "y", "yes")
if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes")
else True,
)

Expand Down
82 changes: 82 additions & 0 deletions tests/unit/test_cli_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest

import click
from multiprocessing import get_context
from typing import List

from dbt.cli.main import cli
from dbt.contracts.project import UserConfig
from dbt.cli.flags import Flags


class TestFlags:
def make_dbt_context(self, context_name: str, args: List[str]) -> click.Context:
ctx = cli.make_context(context_name, args)
return ctx

@pytest.fixture(scope="class")
def run_context(self) -> click.Context:
return self.make_dbt_context("run", ["run"])

def test_which(self, run_context):
flags = Flags(run_context)
assert flags.WHICH == "run"

def test_mp_context(self, run_context):
flags = Flags(run_context)
assert flags.MP_CONTEXT == get_context("spawn")

@pytest.mark.parametrize('param', cli.params)
def test_cli_group_flags_from_params(self, run_context, param):
flags = Flags(run_context)
assert hasattr(flags, param.name.upper())
assert getattr(flags, param.name.upper()) == run_context.params[param.name.lower()]

@pytest.mark.parametrize('do_not_track,expected_anonymous_usage_stats', [
("1", False),
("t", False),
("true", False),
("y", False),
("yes", False),
("false", True),
("anything", True),
("2", True),
])
def test_anonymous_usage_state(self, monkeypatch, run_context, do_not_track, expected_anonymous_usage_stats):
monkeypatch.setenv("DO_NOT_TRACK", do_not_track)

flags = Flags(run_context)
assert flags.ANONYMOUS_USAGE_STATS == expected_anonymous_usage_stats

def test_empty_user_config_uses_default(self, run_context):
user_config = UserConfig()

flags = Flags(run_context, user_config)
assert flags.USE_COLORS == run_context.params['use_colors']

def test_none_user_config_uses_default(self, run_context):
flags = Flags(run_context, None)
assert flags.USE_COLORS == run_context.params['use_colors']

def test_prefer_user_config_to_default(self, run_context):
user_config = UserConfig(use_colors=False)
# ensure default value is not the same as user config
assert run_context.params['use_colors'] is not user_config.use_colors

flags = Flags(run_context, user_config)
assert flags.USE_COLORS == user_config.use_colors

def test_prefer_param_value_to_user_config(self):
user_config = UserConfig(use_colors=False)
context = self.make_dbt_context("run", ["--use-colors", "True", "run"])

flags = Flags(context, user_config)
assert flags.USE_COLORS

def test_prefer_env_to_user_config(self, monkeypatch):
user_config = UserConfig(use_colors=False)
monkeypatch.setenv("DBT_USE_COLORS", "True")
context = self.make_dbt_context("run", ["run"])

flags = Flags(context, user_config)
assert flags.USE_COLORS

0 comments on commit ae77041

Please sign in to comment.