Skip to content

Commit

Permalink
Update current tests and add new test with proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
b-per committed Oct 7, 2024
1 parent 0b11c34 commit c5afd71
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 17 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def get_conn(self, *args, **kwargs) -> Session:
def _paginate(
self, endpoint: str, payload: dict[str, Any] | None = None, proxies: dict[str, str] | None = None
) -> list[Response]:
extra_options = {"proxies": proxies} if proxies is not None else {}
extra_options = {"proxies": proxies} if proxies is not None else None
response = self.run(endpoint=endpoint, data=payload, extra_options=extra_options)
resp_json = response.json()
limit = resp_json["extra"]["filters"]["limit"]
Expand Down Expand Up @@ -332,7 +332,7 @@ def _run_and_get_response(
self.method = method
full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint else None
proxies = self._get_proxies(self.connection)
extra_options = {"proxies": proxies} if proxies is not None else {}
extra_options = {"proxies": proxies} if proxies is not None else None

if paginate:
if isinstance(payload, str):
Expand Down
95 changes: 80 additions & 15 deletions tests/providers/dbt/cloud/hooks/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@
ACCOUNT_ID_CONN = "account_id_conn"
NO_ACCOUNT_ID_CONN = "no_account_id_conn"
SINGLE_TENANT_CONN = "single_tenant_conn"
PROXY_CONN = "proxy_conn"
DEFAULT_ACCOUNT_ID = 11111
ACCOUNT_ID = 22222
SINGLE_TENANT_DOMAIN = "single.tenant.getdbt.com"
EXTRA_PROXIES = {"proxies": {"https": "http://myproxy:1234"}}
TOKEN = "token"
PROJECT_ID = 33333
JOB_ID = 4444
Expand Down Expand Up @@ -136,9 +138,20 @@ def setup_class(self):
host=SINGLE_TENANT_DOMAIN,
)

# Connection with a proxy set in extra parameters
proxy_conn = Connection(
conn_id=PROXY_CONN,
conn_type=DbtCloudHook.conn_type,
login=DEFAULT_ACCOUNT_ID,
password=TOKEN,
host=SINGLE_TENANT_DOMAIN,
extra=EXTRA_PROXIES,
)

db.merge_conn(account_id_conn)
db.merge_conn(no_account_id_conn)
db.merge_conn(host_conn)
db.merge_conn(proxy_conn)

@pytest.mark.parametrize(
argnames="conn_id, url",
Expand Down Expand Up @@ -196,7 +209,7 @@ def test_list_accounts(self, mock_http_run, mock_paginate, conn_id, account_id):
hook.list_accounts()

assert hook.method == "GET"
hook.run.assert_called_once_with(endpoint=None, data=None)
hook.run.assert_called_once_with(endpoint=None, data=None, extra_options=None)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -213,7 +226,9 @@ def test_get_account(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/", data=None)
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/", data=None, extra_options=None
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -232,7 +247,7 @@ def test_list_projects(self, mock_http_run, mock_paginate, conn_id, account_id):
_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_called_once_with(
endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None
endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None, proxies=None
)

@pytest.mark.parametrize(
Expand All @@ -250,7 +265,7 @@ def test_get_project(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", data=None
endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", data=None, extra_options=None
)
hook._paginate.assert_not_called()

Expand All @@ -269,7 +284,9 @@ def test_list_jobs(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook._paginate.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/", payload={"order_by": None, "project_id": None}
endpoint=f"api/v2/accounts/{_account_id}/jobs/",
payload={"order_by": None, "project_id": None},
proxies=None,
)
hook.run.assert_not_called()

Expand All @@ -290,6 +307,7 @@ def test_list_jobs_with_payload(self, mock_http_run, mock_paginate, conn_id, acc
hook._paginate.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/",
payload={"order_by": "-id", "project_id": PROJECT_ID},
proxies=None,
)
hook.run.assert_not_called()

Expand All @@ -307,7 +325,9 @@ def test_get_job(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", data=None)
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", data=None, extra_options=None
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -328,6 +348,7 @@ def test_trigger_job_run(self, mock_http_run, mock_paginate, conn_id, account_id
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps({"cause": cause, "steps_override": None, "schema_override": None}),
extra_options=None,
)
hook._paginate.assert_not_called()

Expand Down Expand Up @@ -359,6 +380,7 @@ def test_trigger_job_run_with_overrides(self, mock_http_run, mock_paginate, conn
data=json.dumps(
{"cause": cause, "steps_override": steps_override, "schema_override": schema_override}
),
extra_options=None,
)
hook._paginate.assert_not_called()

Expand Down Expand Up @@ -393,6 +415,7 @@ def test_trigger_job_run_with_additional_run_configs(
"generate_docs_override": False,
}
),
extra_options=None,
)
hook._paginate.assert_not_called()

Expand Down Expand Up @@ -422,6 +445,7 @@ def test_trigger_job_run_with_longer_cause(self, mock_http_run, mock_paginate, c
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps({"cause": expected_cause, "steps_override": None, "schema_override": None}),
extra_options=None,
)
hook._paginate.assert_not_called()

Expand Down Expand Up @@ -467,7 +491,9 @@ def test_trigger_job_run_with_retry_from_failure(
hook._paginate.assert_not_called()
if should_use_rerun:
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/", data=None
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/",
data=None,
extra_options=None,
)
else:
hook.run.assert_called_once_with(
Expand All @@ -479,8 +505,31 @@ def test_trigger_job_run_with_retry_from_failure(
"schema_override": None,
}
),
extra_options=None,
)

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(PROXY_CONN, ACCOUNT_ID)],
ids=["proxy_connection"],
)
@patch.object(DbtCloudHook, "run")
@patch.object(DbtCloudHook, "_paginate")
def test_trigger_job_run_with_proxy(self, mock_http_run, mock_paginate, conn_id, account_id):
hook = DbtCloudHook(conn_id)
cause = ""
hook.trigger_job_run(job_id=JOB_ID, cause=cause, account_id=account_id)

assert hook.method == "POST"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps({"cause": cause, "steps_override": None, "schema_override": None}),
extra_options={"proxies": {"https": "http://myproxy:1234"}},
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
Expand All @@ -503,6 +552,7 @@ def test_list_job_runs(self, mock_http_run, mock_paginate, conn_id, account_id):
"job_definition_id": None,
"order_by": None,
},
proxies=None,
)

@pytest.mark.parametrize(
Expand All @@ -529,6 +579,7 @@ def test_list_job_runs_with_payload(self, mock_http_run, mock_paginate, conn_id,
"job_definition_id": JOB_ID,
"order_by": "id",
},
proxies=None,
)

@pytest.mark.parametrize(
Expand All @@ -544,7 +595,9 @@ def test_get_job_runs(self, mock_http_run, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None)
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None, extra_options=None
)

@pytest.mark.parametrize(
argnames="conn_id, account_id",
Expand All @@ -561,7 +614,9 @@ def test_get_job_run(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/",
data={"include_related": None},
extra_options=None,
)
hook._paginate.assert_not_called()

Expand All @@ -580,7 +635,9 @@ def test_get_job_run_with_payload(self, mock_http_run, mock_paginate, conn_id, a

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/",
data={"include_related": ["triggers"]},
extra_options=None,
)
hook._paginate.assert_not_called()

Expand Down Expand Up @@ -645,7 +702,7 @@ def test_cancel_job_run(self, mock_http_run, mock_paginate, conn_id, account_id)

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", data=None
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", data=None, extra_options=None
)
hook._paginate.assert_not_called()

Expand All @@ -664,7 +721,9 @@ def test_list_job_run_artifacts(self, mock_http_run, mock_paginate, conn_id, acc

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/",
data={"step": None},
extra_options=None,
)
hook._paginate.assert_not_called()

Expand All @@ -683,7 +742,9 @@ def test_list_job_run_artifacts_with_payload(self, mock_http_run, mock_paginate,

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/",
data={"step": 2},
extra_options=None,
)
hook._paginate.assert_not_called()

Expand All @@ -703,7 +764,9 @@ def test_get_job_run_artifact(self, mock_http_run, mock_paginate, conn_id, accou

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}",
data={"step": None},
extra_options=None,
)
hook._paginate.assert_not_called()

Expand All @@ -723,7 +786,9 @@ def test_get_job_run_artifact_with_payload(self, mock_http_run, mock_paginate, c

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}",
data={"step": 2},
extra_options=None,
)
hook._paginate.assert_not_called()

Expand Down

0 comments on commit c5afd71

Please sign in to comment.