Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add keep_response parameter to HttpHookAsync #1330

Merged
merged 9 commits into from
Sep 25, 2023
65 changes: 39 additions & 26 deletions astronomer/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable

import aiohttp
from aiohttp import ClientResponseError
Expand All @@ -20,6 +22,9 @@ class HttpHookAsync(BaseHook):
API url i.e https://www.google.com/ and optional authentication credentials. Default
headers can also be specified in the Extra field in json format.
:param auth_type: The auth type for the service
:param keep_response: Keep the aiohttp response returned by run method without releasing it.
Use it with caution. Without properly releasing response, it might cause "Unclosed connection" error.
See https://github.com/astronomer/astronomer-providers/issues/909
:type auth_type: AuthBase of python aiohttp lib
"""

Expand All @@ -35,6 +40,8 @@ def __init__(
auth_type: Any = aiohttp.BasicAuth,
retry_limit: int = 3,
retry_delay: float = 1.0,
*,
keep_response: bool = False,
) -> None:
self.http_conn_id = http_conn_id
self.method = method.upper()
Expand All @@ -45,14 +52,15 @@ def __init__(
raise ValueError("Retry limit must be greater than equal to 1")
self.retry_limit = retry_limit
self.retry_delay = retry_delay
self.keep_response = keep_response

async def run(
self,
endpoint: Optional[str] = None,
data: Optional[Union[Dict[str, Any], str]] = None,
headers: Optional[Dict[str, Any]] = None,
extra_options: Optional[Dict[str, Any]] = None,
) -> "ClientResponse":
endpoint: str | None = None,
data: dict[str, Any] | str | None = None,
headers: dict[str, Any] | None = None,
extra_options: dict[str, Any] | None = None,
) -> ClientResponse:
r"""
Performs an asynchronous HTTP request call

Expand All @@ -78,10 +86,10 @@ async def run(
# schema defaults to HTTP
schema = conn.schema if conn.schema else "http"
host = conn.host if conn.host else ""
self.base_url = schema + "://" + host
self.base_url = f"{schema}://{host}"

if conn.port:
self.base_url = self.base_url + ":" + str(conn.port)
self.base_url = f"{self.base_url}:{conn.port}"
if conn.login:
auth = self.auth_type(conn.login, conn.password)
if conn.extra:
Expand All @@ -93,7 +101,7 @@ async def run(
_headers.update(headers)

if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
url = self.base_url + "/" + endpoint
url = f"{self.base_url}/{endpoint}"
else:
url = (self.base_url or "") + (endpoint or "")

Expand All @@ -109,29 +117,34 @@ async def run(

attempt_num = 1
while True:
async with request_func(
response = await request_func(
url,
json=data if self.method in ("POST", "PATCH") else None,
params=data if self.method == "GET" else None,
headers=headers,
auth=auth,
**extra_options,
) as response:
try:
response.raise_for_status()
return response
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt_num,
self.retry_limit,
url,
)
if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
self.log.exception("HTTP error with status: %s", e.status)
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(str(e.status) + ":" + e.message)
)
try:
response.raise_for_status()
if not self.keep_response:
response.release()
return response
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt_num,
self.retry_limit,
url,
)
if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
self.log.exception("HTTP error with status: %s", e.status)
response.release()
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(f"{e.status}:{e.message}")

response.release()

attempt_num += 1
await asyncio.sleep(self.retry_delay)
Expand Down
4 changes: 2 additions & 2 deletions astronomer/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def execute_query(
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
raise AirflowException(
f"Response: {e.response.content}, " f"Status Code: {e.response.status_code}"
f"Response: {e.response.content!r}, " f"Status Code: {e.response.status_code}"
) # pragma: no cover
json_response = response.json()
self.log.info("Snowflake SQL POST API response: %s", json_response)
Expand Down Expand Up @@ -204,7 +204,7 @@ def check_query_output(self, query_ids: list[str]) -> None:
self.log.info(response.json())
except requests.exceptions.HTTPError as e:
raise AirflowException(
f"Response: {e.response.content}, Status Code: {e.response.status_code}"
f"Response: {e.response.content!r}, Status Code: {e.response.status_code}"
)

@staticmethod
Expand Down
43 changes: 42 additions & 1 deletion tests/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest import mock

import pytest
from aiohttp.client_exceptions import ClientConnectionError
from airflow.exceptions import AirflowException
from airflow.models import Connection

Expand Down Expand Up @@ -54,7 +55,8 @@ def get_airflow_connection(unused_conn_id=None):
return Connection(
conn_id="http_default",
conn_type="http",
host="test:8080/",
host="test",
port=8080,
extra='{"bearer": "test"}',
)

Expand All @@ -75,6 +77,45 @@ async def test_post_request(self, aioresponse):
resp = await hook.run("v1/test")
assert resp.status == 200

# Reading the body of the response returned by run() is expected to fail
# when the hook is created without keep_response=True.
@pytest.mark.asyncio
async def test_post_request_and_get_json_without_keep_response(self, aioresponse):
# Hook built with default arguments, so keep_response is left at its default.
hook = HttpHookAsync()
payload = '{"status":{"status": 200}}'

# Register a mocked 200 reply for a POST to http://test:8080/v1/test.
aioresponse.post(
"http://test:8080/v1/test",
status=200,
payload=payload,
reason="OK",
)

# Serve the connection from get_airflow_connection instead of a real lookup.
with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_airflow_connection
):
resp = await hook.run("v1/test")
# The body can no longer be read from the returned response:
# decoding it must raise ClientConnectionError ("Connection closed").
with pytest.raises(ClientConnectionError, match="Connection closed"):
await resp.json()

# With keep_response=True the response returned by run() must still be
# readable, so its JSON body can be decoded by the caller.
@pytest.mark.asyncio
async def test_post_request_and_get_json_with_keep_response(self, aioresponse):
# Opt in to keeping the response via the keyword-only flag.
hook = HttpHookAsync(keep_response=True)
payload = '{"status":{"status": 200}}'

# Register a mocked 200 reply for a POST to http://test:8080/v1/test.
aioresponse.post(
"http://test:8080/v1/test",
status=200,
payload=payload,
reason="OK",
)

# Serve the connection from get_airflow_connection instead of a real lookup.
with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_airflow_connection
):
resp = await hook.run("v1/test")
# Decoding the body must succeed here.
resp_payload = await resp.json()
assert resp.status == 200
# The decoded body is compared with the same str object given as the mocked payload.
assert resp_payload == payload

@pytest.mark.asyncio
async def test_post_request_with_error_code(self, aioresponse):
hook = HttpHookAsync()
Expand Down
Loading