Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set Flags from UserConfig #6266

Merged
merged 6 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20221129-183239.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Click CLI Flags work with UserConfig
time: 2022-11-29T18:32:39.068035-05:00
custom:
Author: michelleark
Issue: "6327"
PR: "6266"
42 changes: 33 additions & 9 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from importlib import import_module
from multiprocessing import get_context
from pprint import pformat as pf
from typing import Set

from click import Context, get_current_context
from click.core import ParameterSource

from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig

if os.name != "nt":
# https://bugs.python.org/issue41567
Expand All @@ -15,12 +20,12 @@

@dataclass(frozen=True)
class Flags:
def __init__(self, ctx: Context = None) -> None:
def __init__(self, ctx: Context = None, user_config: UserConfig = None) -> None:

if ctx is None:
ctx = get_current_context()

def assign_params(ctx):
def assign_params(ctx, params_assigned_from_default):
"""Recursively adds all click params to flag object"""
for param_name, param_value in ctx.params.items():
# N.B. You have to use the base MRO method (object.__setattr__) to set attributes
Expand All @@ -29,29 +34,48 @@ def assign_params(ctx):
if hasattr(self, param_name):
raise Exception(f"Duplicate flag names found in click command: {param_name}")
object.__setattr__(self, param_name.upper(), param_value)
if ctx.get_parameter_source(param_name) == ParameterSource.DEFAULT:
params_assigned_from_default.add(param_name)
if ctx.parent:
assign_params(ctx.parent)
assign_params(ctx.parent, params_assigned_from_default)

assign_params(ctx)
params_assigned_from_default = set() # type: Set[str]
assign_params(ctx, params_assigned_from_default)

# Get the invoked command flags
if hasattr(ctx, "invoked_subcommand") and ctx.invoked_subcommand is not None:
invoked_subcommand = getattr(import_module("dbt.cli.main"), ctx.invoked_subcommand)
invoked_subcommand_name = (
ctx.invoked_subcommand if hasattr(ctx, "invoked_subcommand") else None
)
if invoked_subcommand_name is not None:
invoked_subcommand = getattr(import_module("dbt.cli.main"), invoked_subcommand_name)
invoked_subcommand.allow_extra_args = True
invoked_subcommand.ignore_unknown_options = True
invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv)
assign_params(invoked_subcommand_ctx)
assign_params(invoked_subcommand_ctx, params_assigned_from_default)

if not user_config:
profiles_dir = getattr(self, "PROFILES_DIR", None)
user_config = read_user_config(profiles_dir) if profiles_dir else None

# Overwrite default assignments with user config if available
if user_config:
for param_assigned_from_default in params_assigned_from_default:
user_config_param_value = getattr(user_config, param_assigned_from_default, None)
if user_config_param_value is not None:
object.__setattr__(
self, param_assigned_from_default.upper(), user_config_param_value
)

# Hard coded flags
object.__setattr__(self, "WHICH", ctx.info_name)
object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name)
object.__setattr__(self, "MP_CONTEXT", get_context("spawn"))

# Support console DO NOT TRACK initiave
object.__setattr__(
self,
"ANONYMOUS_USAGE_STATS",
False
if os.getenv("DO_NOT_TRACK", "").lower() in (1, "t", "true", "y", "yes")
if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! env vars != ints. Ever.

else True,
)

Expand Down
82 changes: 82 additions & 0 deletions tests/unit/test_cli_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest

import click
from multiprocessing import get_context
from typing import List

from dbt.cli.main import cli
from dbt.contracts.project import UserConfig
from dbt.cli.flags import Flags


class TestFlags:
def make_dbt_context(self, context_name: str, args: List[str]) -> click.Context:
ctx = cli.make_context(context_name, args)
return ctx

@pytest.fixture(scope="class")
def run_context(self) -> click.Context:
return self.make_dbt_context("run", ["run"])

def test_which(self, run_context):
flags = Flags(run_context)
assert flags.WHICH == "run"

def test_mp_context(self, run_context):
flags = Flags(run_context)
assert flags.MP_CONTEXT == get_context("spawn")

@pytest.mark.parametrize('param', cli.params)
def test_cli_group_flags_from_params(self, run_context, param):
flags = Flags(run_context)
assert hasattr(flags, param.name.upper())
assert getattr(flags, param.name.upper()) == run_context.params[param.name.lower()]

@pytest.mark.parametrize('do_not_track,expected_anonymous_usage_stats', [
("1", False),
("t", False),
("true", False),
("y", False),
("yes", False),
("false", True),
("anything", True),
("2", True),
])
def test_anonymous_usage_state(self, monkeypatch, run_context, do_not_track, expected_anonymous_usage_stats):
monkeypatch.setenv("DO_NOT_TRACK", do_not_track)

flags = Flags(run_context)
assert flags.ANONYMOUS_USAGE_STATS == expected_anonymous_usage_stats

def test_empty_user_config_uses_default(self, run_context):
user_config = UserConfig()

flags = Flags(run_context, user_config)
assert flags.USE_COLORS == run_context.params['use_colors']

def test_none_user_config_uses_default(self, run_context):
flags = Flags(run_context, None)
assert flags.USE_COLORS == run_context.params['use_colors']

def test_prefer_user_config_to_default(self, run_context):
user_config = UserConfig(use_colors=False)
# ensure default value is not the same as user config
assert run_context.params['use_colors'] is not user_config.use_colors

flags = Flags(run_context, user_config)
assert flags.USE_COLORS == user_config.use_colors

def test_prefer_param_value_to_user_config(self):
user_config = UserConfig(use_colors=False)
context = self.make_dbt_context("run", ["--use-colors", "True", "run"])

flags = Flags(context, user_config)
assert flags.USE_COLORS

def test_prefer_env_to_user_config(self, monkeypatch):
user_config = UserConfig(use_colors=False)
monkeypatch.setenv("DBT_USE_COLORS", "True")
context = self.make_dbt_context("run", ["run"])

flags = Flags(context, user_config)
assert flags.USE_COLORS