Skip to content

Commit

Permalink
Feature: Added event_handler parameter in MSGraphAsyncOperator (#42539)
Browse files Browse the repository at this point in the history
* refactor: Added parameter in MSGraphAsyncOperator to allow overriding default event_handler

* docs: Added docstring for event_handler parameter in MSGraphAsyncOperator

* refactor: Fixed TestMSGraphAsyncOperator

* refactor: Check if event is not None

* refactor: Register the TextParseNodeFactory and JsonParseNodeFactory so error messages get handled correctly in RequestAdapter

* refactor: Reorganized import for TestMSGraphAsyncOperator

* refactor: Added missing kiota-serialization packages in azure provider

* refactor: Updated provider dependencies

* refactor: Reorganized import of TestKiotaRequestAdapterHook

* refactor: Downgraded version of json kiota serialization

* refactor: Updated provider dependencies

* refactor: Put import of Context in TYPE_CHECKING block

* refactor: Fixed lookup of tenant-id

* refactor: Fixed kiota serialization dependencies to 1.0.0 to avoid pendulum dependency issues for backward compatibility

* refactor: Updated provider dependencies

* refactored: Fixed import of test_utils in test_dag_run

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Oct 15, 2024
1 parent 4c5ad9c commit 2b101e2
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 5 deletions.
2 changes: 2 additions & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,8 @@
"azure-synapse-artifacts>=0.17.0",
"azure-synapse-spark>=0.2.0",
"microsoft-kiota-http>=1.3.0,!=1.3.4",
"microsoft-kiota-serialization-json==1.0.0",
"microsoft-kiota-serialization-text==1.0.0",
"msgraph-core>=1.0.0"
],
"devel-deps": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.response_handler import ResponseHandler
from kiota_abstractions.serialization import ParseNodeFactoryRegistry
from kiota_authentication_azure.azure_identity_authentication_provider import (
AzureIdentityAuthenticationProvider,
)
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from kiota_http.middleware.options import ResponseHandlerOption
from kiota_serialization_json.json_parse_node_factory import JsonParseNodeFactory
from kiota_serialization_text.text_parse_node_factory import TextParseNodeFactory
from msgraph_core import APIVersion, GraphClientFactory
from msgraph_core._enums import NationalClouds

Expand Down Expand Up @@ -249,8 +252,12 @@ def get_conn(self) -> RequestAdapter:
scopes=scopes,
allowed_hosts=allowed_hosts,
)
parse_node_factory = ParseNodeFactoryRegistry()
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"] = TextParseNodeFactory()
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] = JsonParseNodeFactory()
request_adapter = HttpxRequestAdapter(
authentication_provider=auth_provider,
parse_node_factory=parse_node_factory,
http_client=http_client,
base_url=base_url,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@
from airflow.utils.context import Context


def default_event_handler(context: Context, event: dict[Any, Any] | None = None) -> Any:
if event:
if event.get("status") == "failure":
raise AirflowException(event.get("message"))

return event.get("response")


class MSGraphAsyncOperator(BaseOperator):
"""
A Microsoft Graph API operator which allows you to execute REST call to the Microsoft Graph API.
Expand All @@ -69,6 +77,9 @@ class MSGraphAsyncOperator(BaseOperator):
:param result_processor: Function to further process the response from MS Graph API
(default is lambda: context, response: response). When the response returned by the
`KiotaRequestAdapterHook` are bytes, then those will be base64 encoded into a string.
:param event_handler: Function to process the event returned from `MSGraphTrigger`. By default, when the
event returned by the `MSGraphTrigger` has a failed status, an AirflowException is being raised with
the message from the event, otherwise the response from the event payload is returned.
:param serializer: Class which handles response serialization (default is ResponseSerializer).
Bytes will be base64 encoded into a string, so it can be stored as an XCom.
"""
Expand Down Expand Up @@ -102,6 +113,7 @@ def __init__(
api_version: APIVersion | str | None = None,
pagination_function: Callable[[MSGraphAsyncOperator, dict, Context], tuple[str, dict]] | None = None,
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
event_handler: Callable[[Context, dict[Any, Any] | None], Any] | None = None,
serializer: type[ResponseSerializer] = ResponseSerializer,
**kwargs: Any,
):
Expand All @@ -121,6 +133,7 @@ def __init__(
self.api_version = api_version
self.pagination_function = pagination_function or self.paginate
self.result_processor = result_processor
self.event_handler = event_handler or default_event_handler
self.serializer: ResponseSerializer = serializer()

def execute(self, context: Context) -> None:
Expand Down Expand Up @@ -158,10 +171,7 @@ def execute_complete(
if event:
self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event)

if event.get("status") == "failure":
raise AirflowException(event.get("message"))

response = event.get("response")
response = self.event_handler(context, event)

self.log.debug("response: %s", response)

Expand Down
2 changes: 2 additions & 0 deletions providers/src/airflow/providers/microsoft/azure/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ dependencies:
# msgraph-core has transient import failures with microsoft-kiota-http==1.3.4
# See https:/microsoftgraph/msgraph-sdk-python-core/issues/706
- microsoft-kiota-http>=1.3.0,!=1.3.4
- microsoft-kiota-serialization-json==1.0.0
- microsoft-kiota-serialization-text==1.0.0

devel-dependencies:
- pywinrm
Expand Down
45 changes: 44 additions & 1 deletion providers/tests/microsoft/azure/hooks/test_msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
import asyncio
from json import JSONDecodeError
from typing import TYPE_CHECKING
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from httpx import Response
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from kiota_serialization_json.json_parse_node import JsonParseNode
from kiota_serialization_text.text_parse_node import TextParseNode
from msgraph_core import APIVersion, NationalClouds
from opentelemetry.trace import Span

from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException
from airflow.providers.microsoft.azure.hooks.msgraph import (
Expand Down Expand Up @@ -175,6 +179,45 @@ def test_encoded_query_parameters(self):

assert actual == {"%24expand": "reports,users,datasets,dataflows,dashboards", "%24top": 5000}

@pytest.mark.asyncio
async def test_throw_failed_responses_with_text_plain_content_type(self):
with patch(
"airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection,
):
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
response = Mock(spec=Response)
response.headers = {"content-type": "text/plain"}
response.status_code = 429
response.content = b"TenantThrottleThresholdExceeded"
response.is_success = False
span = Mock(spec=Span)

actual = await hook.get_conn().get_root_parse_node(response, span, span)

assert isinstance(actual, TextParseNode)
assert actual.get_str_value() == "TenantThrottleThresholdExceeded"

@pytest.mark.asyncio
async def test_throw_failed_responses_with_application_json_content_type(self):
with patch(
"airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection,
):
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
response = Mock(spec=Response)
response.headers = {"content-type": "application/json"}
response.status_code = 429
response.content = b'{"error": {"code": "TenantThrottleThresholdExceeded"}}'
response.is_success = False
span = Mock(spec=Span)

actual = await hook.get_conn().get_root_parse_node(response, span, span)

assert isinstance(actual, JsonParseNode)
error_code = actual.get_child_node("error").get_child_node("code").get_str_value()
assert error_code == "TenantThrottleThresholdExceeded"


class TestResponseHandler:
def test_default_response_handler_when_json(self):
Expand Down
30 changes: 30 additions & 0 deletions providers/tests/microsoft/azure/operators/test_msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import locale
from base64 import b64encode
from typing import TYPE_CHECKING, Any

import pytest

Expand All @@ -35,6 +36,9 @@
mock_response,
)

if TYPE_CHECKING:
from airflow.utils.context import Context


class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
Expand Down Expand Up @@ -101,6 +105,32 @@ def test_execute_when_an_exception_occurs(self):
with pytest.raises(AirflowException):
self.execute_operator(operator)

@pytest.mark.db_test
def test_execute_when_an_exception_occurs_on_custom_event_handler(self):
with self.patch_hook_and_request_adapter(AirflowException("An error occurred")):

def custom_event_handler(context: Context, event: dict[Any, Any] | None = None):
if event:
if event.get("status") == "failure":
return None

return event.get("response")

operator = MSGraphAsyncOperator(
task_id="users_delta",
conn_id="msgraph_api",
url="users/delta",
event_handler=custom_event_handler,
)

results, events = self.execute_operator(operator)

assert not results
assert len(events) == 1
assert isinstance(events[0], TriggerEvent)
assert events[0].payload["status"] == "failure"
assert events[0].payload["message"] == "An error occurred"

@pytest.mark.db_test
def test_execute_when_response_is_bytes(self):
content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
Expand Down

0 comments on commit 2b101e2

Please sign in to comment.