From 227c9a524fb9ec2b4678d0486e0375e6ca0c5d9c Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Fri, 19 Apr 2024 17:00:05 +0545 Subject: [PATCH 1/9] Refactor Dataproc cluster cancellation handling in triggers --- .../google/cloud/operators/dataproc.py | 1 + .../google/cloud/triggers/dataproc.py | 103 +++++++++++++++--- .../google/cloud/triggers/test_dataproc.py | 1 + 3 files changed, 91 insertions(+), 14 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index edbfbd3f39b..e4fccfedd87 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -816,6 +816,7 @@ def execute(self, context: Context) -> dict: gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, polling_interval_seconds=self.polling_interval_seconds, + delete_on_error=self.delete_on_error, ), method_name="execute_complete", ) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index f0aecddb4a8..8708c9e2451 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -25,7 +25,7 @@ from typing import Any, AsyncIterator, Sequence from google.api_core.exceptions import NotFound -from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus +from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType @@ -43,6 +43,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, polling_interval_seconds: int = 30, + delete_on_error: bool = True, ): super().__init__() self.region = region @@ -50,6 +51,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.polling_interval_seconds = polling_interval_seconds + self.delete_on_error = delete_on_error def get_async_hook(self): return DataprocAsyncHook( @@ -140,24 +142,97 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, "polling_interval_seconds": self.polling_interval_seconds, + "delete_on_error": self.delete_on_error, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: - while True: - cluster = await self.get_async_hook().get_cluster( - project_id=self.project_id, region=self.region, cluster_name=self.cluster_name + """Run the trigger.""" + try: + while True: + cluster = await self.fetch_cluster_status() + if self.is_terminal_state(cluster.status.state): + if cluster.status.state == ClusterStatus.State.ERROR: + await self.gather_diagnostics_and_maybe_delete(cluster) + else: + yield TriggerEvent( + { + "cluster_name": self.cluster_name, + "cluster_state": cluster.status.state, + "cluster": cluster, + } + ) + break + self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) + await asyncio.sleep(self.polling_interval_seconds) + except asyncio.CancelledError: + await self.handle_cancellation() + + async def fetch_cluster_status(self) -> Cluster: + """Fetch the cluster status.""" + return await self.get_async_hook().get_cluster( + project_id=self.project_id, region=self.region, cluster_name=self.cluster_name + ) + + def is_terminal_state(self, state: ClusterStatus.State) -> bool: + """ + Check if the state is terminal. + + :param state: The state of the cluster. + """ + return state in (ClusterStatus.State.ERROR, ClusterStatus.State.RUNNING) + + async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster): + """ + Gather diagnostics and maybe delete the cluster. + + :param cluster: The cluster to gather diagnostics for. + """ + self.log.info("Cluster is in ERROR state. Gathering diagnostic information.") + try: + operation = await self.get_async_hook().diagnose_cluster( + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) - state = cluster.status.state - self.log.info("Dataproc cluster: %s is in state: %s", self.cluster_name, state) - if state in ( - ClusterStatus.State.ERROR, - ClusterStatus.State.RUNNING, - ): - break - self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) - await asyncio.sleep(self.polling_interval_seconds) - yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster}) + result = await operation.result() + gcs_uri = str(result.response.value) + self.log.info( + "Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri + ) + except Exception as e: + self.log.error("Failed to diagnose cluster: %s", e) + + if self.delete_on_error: + await self.get_async_hook().delete_cluster( + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id + ) + return TriggerEvent( + { + "cluster_name": self.cluster_name, + "cluster_state": cluster.status.state, + "cluster": None, + "action": "deleted", + } + ) + else: + return TriggerEvent( + {"cluster_name": self.cluster_name, "cluster_state": cluster.status.state, "cluster": cluster} + ) + + async def handle_cancellation(self) -> None: + """Handle the cancellation of the trigger, cleaning up resources if necessary.""" + self.log.info("Cancellation requested. Deleting the cluster if created.") + try: + if self.delete_on_error: + cluster = await self.fetch_cluster_status() + if cluster.status.state == ClusterStatus.State.ERROR: + await self.get_async_hook().async_delete_cluster( + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id + ) + self.log.info("Deleted cluster due to ERROR state during cancellation.") + else: + self.log.info("Cancellation did not require cluster deletion.") + except Exception as e: + self.log.error("Error during cancellation handling: %s", e) class DataprocBatchTrigger(DataprocBaseTrigger): diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 45607d51b8a..404d05eda37 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -147,6 +147,7 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c "gcp_conn_id": TEST_GCP_CONN_ID, "impersonation_chain": None, "polling_interval_seconds": TEST_POLL_INTERVAL, + "delete_on_error": True, } @pytest.mark.asyncio From 290e98160e0e60118800a7d3bce7aaae07cfbc86 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Mon, 22 Apr 2024 20:16:57 +0545 Subject: [PATCH 2/9] Add tests --- .../google/cloud/triggers/dataproc.py | 6 +- .../google/cloud/triggers/test_dataproc.py | 84 ++++++++++++++++--- 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 8708c9e2451..926b1f6d6ad 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -151,7 +151,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: try: while True: cluster = await self.fetch_cluster_status() - if self.is_terminal_state(cluster.status.state): + if self.check_cluster_state(cluster.status.state): if cluster.status.state == ClusterStatus.State.ERROR: await self.gather_diagnostics_and_maybe_delete(cluster) else: @@ -174,9 +174,9 @@ async def fetch_cluster_status(self) -> Cluster: project_id=self.project_id, region=self.region, cluster_name=self.cluster_name ) - def is_terminal_state(self, state: ClusterStatus.State) -> bool: + def check_cluster_state(self, state: ClusterStatus.State) -> bool: """ - Check if the state is terminal. + Check if the state is error or running. :param state: The state of the cluster. """ diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 404d05eda37..5f67bea5896 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -24,6 +24,7 @@ import pytest from google.cloud.dataproc_v1 import Batch, ClusterStatus from google.protobuf.any_pb2 import Any +from google.rpc.error_details_pb2 import ErrorInfo from google.rpc.status_pb2 import Status from airflow.providers.google.cloud.triggers.dataproc import ( @@ -70,6 +71,7 @@ def batch_trigger(): gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=None, polling_interval_seconds=TEST_POLL_INTERVAL, + delete_on_error=True, ) return trigger @@ -96,6 +98,7 @@ def diagnose_operation_trigger(): gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=None, polling_interval_seconds=TEST_POLL_INTERVAL, + delete_on_error=True, ) @@ -176,27 +179,37 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully( @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") + @mock.patch( + "airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster", + return_value=asyncio.Future(), + ) + @mock.patch("google.auth.default") async def test_async_cluster_trigger_run_returns_error_event( - self, mock_hook, cluster_trigger, async_get_cluster + self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog ): - mock_hook.return_value = async_get_cluster( + mock_credentials = mock.MagicMock() + mock_credentials.universe_domain = "googleapis.com" + + mock_auth.return_value = (mock_credentials, "project-id") + + mock_delete_cluster.return_value = asyncio.Future() + mock_delete_cluster.return_value.set_result(None) + + mock_get_cluster.return_value = async_get_cluster( project_id=TEST_PROJECT_ID, region=TEST_REGION, cluster_name=TEST_CLUSTER_NAME, status=ClusterStatus(state=ClusterStatus.State.ERROR), ) - actual_event = await cluster_trigger.run().asend(None) - await asyncio.sleep(0.5) + caplog.set_level(logging.INFO) - expected_event = TriggerEvent( - { - "cluster_name": TEST_CLUSTER_NAME, - "cluster_state": ClusterStatus.State.ERROR, - "cluster": actual_event.payload["cluster"], - } - ) - assert expected_event == actual_event + trigger_event = None + async for event in cluster_trigger.run(): + trigger_event = event + + assert trigger_event is None, "Expected an event to be emitted" + assert "Cluster is in ERROR state. Gathering diagnostic information." in caplog.text @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") @@ -216,9 +229,54 @@ async def test_cluster_run_loop_is_still_running( await asyncio.sleep(0.5) assert not task.done() - assert f"Current state is: {ClusterStatus.State.CREATING}" + assert f"Current state is: {ClusterStatus.State.CREATING}." assert f"Sleeping for {TEST_POLL_INTERVAL} seconds." + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") + async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster): + mock_get_cluster.return_value = async_get_cluster( + status=ClusterStatus(state=ClusterStatus.State.RUNNING) + ) + cluster = await cluster_trigger.fetch_cluster_status() + + assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING" + + def test_check_luster_state(self, cluster_trigger): + """Test if specific states are correctly identified.""" + assert cluster_trigger.check_cluster_state( + ClusterStatus.State.RUNNING + ), "RUNNING should be correct state" + assert cluster_trigger.check_cluster_state(ClusterStatus.State.ERROR), "ERROR should be correct state" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.diagnose_cluster") + @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster") + async def test_gather_diagnostics_and_maybe_delete( + self, mock_delete_cluster, mock_diagnose_cluster, cluster_trigger, async_get_cluster + ): + error_info = ErrorInfo(reason="DIAGNOSTICS") + any_message = Any() + any_message.Pack(error_info) + + diagnose_future = asyncio.Future() + status = Status() + status.details.add().CopyFrom(any_message) + diagnose_future.set_result(status) + mock_diagnose_cluster.return_value = diagnose_future + + delete_future = asyncio.Future() + delete_future.set_result(None) + mock_delete_cluster.return_value = delete_future + + cluster = await async_get_cluster(status=ClusterStatus(state=ClusterStatus.State.ERROR)) + event = await cluster_trigger.gather_diagnostics_and_maybe_delete(cluster) + + mock_delete_cluster.assert_called_once() + assert ( + "deleted" in event.payload["action"] + ), "The cluster should be deleted due to error state and delete_on_error=True" + @pytest.mark.db_test class TestDataprocBatchTrigger: From aade3686f9ba4b69db27d95a2578a37a07d03060 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 23 Apr 2024 15:18:52 +0545 Subject: [PATCH 3/9] Fix PR comments --- .../google/cloud/triggers/dataproc.py | 105 +++++++++++------- .../google/cloud/triggers/test_dataproc.py | 11 +- 2 files changed, 67 insertions(+), 49 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 926b1f6d6ad..92581a0556e 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -27,7 +27,7 @@ from google.api_core.exceptions import NotFound from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus -from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook +from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -59,6 +59,12 @@ def get_async_hook(self): impersonation_chain=self.impersonation_chain, ) + def get_sync_hook(self): + return DataprocHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + class DataprocSubmitTrigger(DataprocBaseTrigger): """ @@ -150,39 +156,74 @@ async def run(self) -> AsyncIterator[TriggerEvent]: """Run the trigger.""" try: while True: - cluster = await self.fetch_cluster_status() - if self.check_cluster_state(cluster.status.state): - if cluster.status.state == ClusterStatus.State.ERROR: - await self.gather_diagnostics_and_maybe_delete(cluster) - else: - yield TriggerEvent( - { - "cluster_name": self.cluster_name, - "cluster_state": cluster.status.state, - "cluster": cluster, - } - ) + cluster = await self.fetch_cluster() + state = cluster.status.state + if state == ClusterStatus.State.ERROR: + await self.gather_diagnostics_and_delete_on_error(cluster) + break + elif state == ClusterStatus.State.RUNNING: + yield TriggerEvent( + { + "cluster_name": self.cluster_name, + "cluster_state": state, + "cluster": cluster, + } + ) break + self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) await asyncio.sleep(self.polling_interval_seconds) except asyncio.CancelledError: - await self.handle_cancellation() + try: + if self.delete_on_error: + self.log.info("Deleting cluster %s.", self.cluster_name) + self.get_sync_hook().delete_cluster( + region=self.region, cluster_name=self.cluster_name, project_id=self.project_id + ) + self.log.info("Deleted cluster %s during cancellation.", self.cluster_name) + self.log.info("Cluster deletion initiated, awaiting completion...") + async for event in self.wait_until_cluster_deleted(): + if event["status"] == "success": + self.log.info("Cluster deletion confirmed.") + elif event["status"] == "error": + self.log.error("Cluster deletion failed with message: %s", event["message"]) + self.log.info("Finished handling cluster deletion.") + except Exception as e: + self.log.error("Error during cancellation handling: %s", e) + + async def wait_until_cluster_deleted(self): + """Wait until the cluster is confirmed as deleted.""" + end_time = time.time() + self.polling_interval_seconds * 10 # Set end time for loop + try: + while time.time() < end_time: + try: + await self.get_async_hook().get_cluster( + region=self.region, + cluster_name=self.cluster_name, + project_id=self.project_id, + ) + self.log.info( + "Cluster still exists. Sleeping for %s seconds.", self.polling_interval_seconds + ) + await asyncio.sleep(self.polling_interval_seconds) + except NotFound: + self.log.info("Cluster successfully deleted.") + yield TriggerEvent({"status": "success", "message": "Cluster deleted successfully."}) + return + except Exception as e: + self.log.error("Error while checking for cluster deletion: %s", e) + yield TriggerEvent({"status": "error", "message": str(e)}) + yield TriggerEvent( + {"status": "error", "message": "Timeout - cluster deletion not confirmed within expected time."} + ) - async def fetch_cluster_status(self) -> Cluster: + async def fetch_cluster(self) -> Cluster: """Fetch the cluster status.""" return await self.get_async_hook().get_cluster( project_id=self.project_id, region=self.region, cluster_name=self.cluster_name ) - def check_cluster_state(self, state: ClusterStatus.State) -> bool: - """ - Check if the state is error or running. - - :param state: The state of the cluster. - """ - return state in (ClusterStatus.State.ERROR, ClusterStatus.State.RUNNING) - - async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster): + async def gather_diagnostics_and_delete_on_error(self, cluster: Cluster): """ Gather diagnostics and maybe delete the cluster. @@ -218,22 +259,6 @@ async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster): {"cluster_name": self.cluster_name, "cluster_state": cluster.status.state, "cluster": cluster} ) - async def handle_cancellation(self) -> None: - """Handle the cancellation of the trigger, cleaning up resources if necessary.""" - self.log.info("Cancellation requested. Deleting the cluster if created.") - try: - if self.delete_on_error: - cluster = await self.fetch_cluster_status() - if cluster.status.state == ClusterStatus.State.ERROR: - await self.get_async_hook().async_delete_cluster( - region=self.region, cluster_name=self.cluster_name, project_id=self.project_id - ) - self.log.info("Deleted cluster due to ERROR state during cancellation.") - else: - self.log.info("Cancellation did not require cluster deletion.") - except Exception as e: - self.log.error("Error during cancellation handling: %s", e) - class DataprocBatchTrigger(DataprocBaseTrigger): """ diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 5f67bea5896..3e36c9d9440 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -238,17 +238,10 @@ async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, asy mock_get_cluster.return_value = async_get_cluster( status=ClusterStatus(state=ClusterStatus.State.RUNNING) ) - cluster = await cluster_trigger.fetch_cluster_status() + cluster = await cluster_trigger.fetch_cluster() assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING" - def test_check_luster_state(self, cluster_trigger): - """Test if specific states are correctly identified.""" - assert cluster_trigger.check_cluster_state( - ClusterStatus.State.RUNNING - ), "RUNNING should be correct state" - assert cluster_trigger.check_cluster_state(ClusterStatus.State.ERROR), "ERROR should be correct state" - @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.diagnose_cluster") @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster") @@ -270,7 +263,7 @@ async def test_gather_diagnostics_and_maybe_delete( mock_delete_cluster.return_value = delete_future cluster = await async_get_cluster(status=ClusterStatus(state=ClusterStatus.State.ERROR)) - event = await cluster_trigger.gather_diagnostics_and_maybe_delete(cluster) + event = await cluster_trigger.gather_diagnostics_and_delete_on_error(cluster) mock_delete_cluster.assert_called_once() assert ( From c778078b5749f563e934905609309d11b13ccd42 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 23 Apr 2024 15:18:52 +0545 Subject: [PATCH 4/9] Fix PR comments --- .../providers/google/cloud/triggers/dataproc.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 92581a0556e..65f44c3e643 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -27,6 +27,7 @@ from google.api_core.exceptions import NotFound from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID @@ -181,7 +182,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) self.log.info("Deleted cluster %s during cancellation.", self.cluster_name) - self.log.info("Cluster deletion initiated, awaiting completion...") + self.log.info("Cluster deletion initiated.") async for event in self.wait_until_cluster_deleted(): if event["status"] == "success": self.log.info("Cluster deletion confirmed.") @@ -190,6 +191,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.log.info("Finished handling cluster deletion.") except Exception as e: self.log.error("Error during cancellation handling: %s", e) + raise AirflowException("Error during cancellation handling: %s", e) async def wait_until_cluster_deleted(self): """Wait until the cluster is confirmed as deleted.""" @@ -246,17 +248,10 @@ async def gather_diagnostics_and_delete_on_error(self, cluster: Cluster): await self.get_async_hook().delete_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) - return TriggerEvent( - { - "cluster_name": self.cluster_name, - "cluster_state": cluster.status.state, - "cluster": None, - "action": "deleted", - } - ) + self.log.info("Cluster %s has been deleted.", self.cluster_name) else: - return TriggerEvent( - {"cluster_name": self.cluster_name, "cluster_state": cluster.status.state, "cluster": cluster} + self.log.info( + "Cluster %s is not be deleted as delete_on_error is set to False.", self.cluster_name ) From e3918ea22e17ce77fd6b41c6ad0c1d1be6edeb26 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:45:53 +0545 Subject: [PATCH 5/9] Fix PR comments --- .../google/cloud/triggers/dataproc.py | 33 ------------------- .../google/cloud/triggers/test_dataproc.py | 14 +++++--- 2 files changed, 9 insertions(+), 38 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 65f44c3e643..1abbe2a19b4 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -182,43 +182,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) self.log.info("Deleted cluster %s during cancellation.", self.cluster_name) - self.log.info("Cluster deletion initiated.") - async for event in self.wait_until_cluster_deleted(): - if event["status"] == "success": - self.log.info("Cluster deletion confirmed.") - elif event["status"] == "error": - self.log.error("Cluster deletion failed with message: %s", event["message"]) - self.log.info("Finished handling cluster deletion.") except Exception as e: self.log.error("Error during cancellation handling: %s", e) raise AirflowException("Error during cancellation handling: %s", e) - async def wait_until_cluster_deleted(self): - """Wait until the cluster is confirmed as deleted.""" - end_time = time.time() + self.polling_interval_seconds * 10 # Set end time for loop - try: - while time.time() < end_time: - try: - await self.get_async_hook().get_cluster( - region=self.region, - cluster_name=self.cluster_name, - project_id=self.project_id, - ) - self.log.info( - "Cluster still exists. Sleeping for %s seconds.", self.polling_interval_seconds - ) - await asyncio.sleep(self.polling_interval_seconds) - except NotFound: - self.log.info("Cluster successfully deleted.") - yield TriggerEvent({"status": "success", "message": "Cluster deleted successfully."}) - return - except Exception as e: - self.log.error("Error while checking for cluster deletion: %s", e) - yield TriggerEvent({"status": "error", "message": str(e)}) - yield TriggerEvent( - {"status": "error", "message": "Timeout - cluster deletion not confirmed within expected time."} - ) - async def fetch_cluster(self) -> Cluster: """Fetch the cluster status.""" return await self.get_async_hook().get_cluster( diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 3e36c9d9440..a3107bfd546 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -263,12 +263,16 @@ async def test_gather_diagnostics_and_maybe_delete( mock_delete_cluster.return_value = delete_future cluster = await async_get_cluster(status=ClusterStatus(state=ClusterStatus.State.ERROR)) - event = await cluster_trigger.gather_diagnostics_and_delete_on_error(cluster) + self.delete_on_error = True - mock_delete_cluster.assert_called_once() - assert ( - "deleted" in event.payload["action"] - ), "The cluster should be deleted due to error state and delete_on_error=True" + await cluster_trigger.gather_diagnostics_and_delete_on_error(cluster) + + mock_diagnose_cluster.assert_called_once_with( + region="region", cluster_name="cluster_name", project_id="project-id" + ) + mock_delete_cluster.assert_called_once_with( + region="region", cluster_name="cluster_name", project_id="project-id" + ) @pytest.mark.db_test From fe941940997c0cb6ab26e15442e4b3ce832fe6c8 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 24 Apr 2024 17:41:38 +0545 Subject: [PATCH 6/9] Fix the PR comments --- .../google/cloud/triggers/dataproc.py | 36 ++++++------- .../google/cloud/triggers/test_dataproc.py | 50 +++++++++---------- 2 files changed, 39 insertions(+), 47 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 1abbe2a19b4..f4861413b52 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -61,6 +61,7 @@ def get_async_hook(self): ) def get_sync_hook(self): + # The sync hook is used to delete the cluster in case of cancellation of task. return DataprocHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -160,8 +161,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: cluster = await self.fetch_cluster() state = cluster.status.state if state == ClusterStatus.State.ERROR: - await self.gather_diagnostics_and_delete_on_error(cluster) - break + await self.delete_when_error_occurred(cluster) + yield TriggerEvent( + { + "cluster_name": self.cluster_name, + "cluster_state": state.DELETING, + "cluster": cluster, + } + ) + return elif state == ClusterStatus.State.RUNNING: yield TriggerEvent( { @@ -170,14 +178,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: "cluster": cluster, } ) - break - + return self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) await asyncio.sleep(self.polling_interval_seconds) except asyncio.CancelledError: try: if self.delete_on_error: self.log.info("Deleting cluster %s.", self.cluster_name) + # The sync hook is used to delete the cluster in case of cancellation of task. self.get_sync_hook().delete_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) @@ -192,26 +200,14 @@ async def fetch_cluster(self) -> Cluster: project_id=self.project_id, region=self.region, cluster_name=self.cluster_name ) - async def gather_diagnostics_and_delete_on_error(self, cluster: Cluster): + async def delete_when_error_occurred(self, cluster: Cluster): """ - Gather diagnostics and maybe delete the cluster. + Delete the cluster on error. - :param cluster: The cluster to gather diagnostics for. + :param cluster: The cluster to delete. """ - self.log.info("Cluster is in ERROR state. Gathering diagnostic information.") - try: - operation = await self.get_async_hook().diagnose_cluster( - region=self.region, cluster_name=self.cluster_name, project_id=self.project_id - ) - result = await operation.result() - gcs_uri = str(result.response.value) - self.log.info( - "Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri - ) - except Exception as e: - self.log.error("Failed to diagnose cluster: %s", e) - if self.delete_on_error: + self.log.info("Deleting cluster %s.", self.cluster_name) await self.get_async_hook().delete_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index a3107bfd546..6d1042d3126 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -22,9 +22,8 @@ from unittest import mock import pytest -from google.cloud.dataproc_v1 import Batch, ClusterStatus +from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus from google.protobuf.any_pb2 import Any -from google.rpc.error_details_pb2 import ErrorInfo from google.rpc.status_pb2 import Status from airflow.providers.google.cloud.triggers.dataproc import ( @@ -208,8 +207,8 @@ async def test_async_cluster_trigger_run_returns_error_event( async for event in cluster_trigger.run(): trigger_event = event - assert trigger_event is None, "Expected an event to be emitted" - assert "Cluster is in ERROR state. Gathering diagnostic information." in caplog.text + assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME + assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") @@ -243,37 +242,34 @@ async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, asy assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING" @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.diagnose_cluster") @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster") - async def test_gather_diagnostics_and_maybe_delete( - self, mock_delete_cluster, mock_diagnose_cluster, cluster_trigger, async_get_cluster - ): - error_info = ErrorInfo(reason="DIAGNOSTICS") - any_message = Any() - any_message.Pack(error_info) - - diagnose_future = asyncio.Future() - status = Status() - status.details.add().CopyFrom(any_message) - diagnose_future.set_result(status) - mock_diagnose_cluster.return_value = diagnose_future + async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_trigger): + mock_cluster = mock.MagicMock(spec=Cluster) + type(mock_cluster).status = mock.PropertyMock( + return_value=mock.MagicMock(state=ClusterStatus.State.ERROR) + ) - delete_future = asyncio.Future() - delete_future.set_result(None) - mock_delete_cluster.return_value = delete_future + mock_delete_future = asyncio.Future() + mock_delete_future.set_result(None) + mock_delete_cluster.return_value = mock_delete_future - cluster = await async_get_cluster(status=ClusterStatus(state=ClusterStatus.State.ERROR)) - self.delete_on_error = True + cluster_trigger.delete_on_error = True - await cluster_trigger.gather_diagnostics_and_delete_on_error(cluster) + await cluster_trigger.delete_when_error_occurred(mock_cluster) - mock_diagnose_cluster.assert_called_once_with( - region="region", cluster_name="cluster_name", project_id="project-id" - ) mock_delete_cluster.assert_called_once_with( - region="region", cluster_name="cluster_name", project_id="project-id" + region=cluster_trigger.region, + cluster_name=cluster_trigger.cluster_name, + project_id=cluster_trigger.project_id, ) + mock_delete_cluster.reset_mock() + cluster_trigger.delete_on_error = False + + await cluster_trigger.delete_when_error_occurred(mock_cluster) + + mock_delete_cluster.assert_not_called() + @pytest.mark.db_test class TestDataprocBatchTrigger: From 7c28a2d5bbe03cbc049bb91f18aebe4749283c1f Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 24 Apr 2024 19:00:49 +0545 Subject: [PATCH 7/9] Fix the PR comments --- airflow/providers/google/cloud/triggers/dataproc.py | 12 +++++++++--- .../providers/google/cloud/triggers/test_dataproc.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index f4861413b52..5e082fc9cd7 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -61,7 +61,10 @@ def get_async_hook(self): ) def get_sync_hook(self): - # The sync hook is used to delete the cluster in case of cancellation of task. + # The synchronous hook is utilized to delete the cluster when a task is cancelled. + # This is because the asynchronous hook deletion is not awaited when the trigger task + # is cancelled. The call for deleting the cluster through the sync hook is not a blocking + # call, which means it does not wait until the cluster is deleted. return DataprocHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -165,7 +168,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent( { "cluster_name": self.cluster_name, - "cluster_state": state.DELETING, + "cluster_state": state.ERROR, "cluster": cluster, } ) @@ -185,7 +188,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: try: if self.delete_on_error: self.log.info("Deleting cluster %s.", self.cluster_name) - # The sync hook is used to delete the cluster in case of cancellation of task. + # The synchronous hook is utilized to delete the cluster when a task is cancelled. + # This is because the asynchronous hook deletion is not awaited when the trigger task + # is cancelled. The call for deleting the cluster through the sync hook is not a blocking + # call, which means it does not wait until the cluster is deleted. self.get_sync_hook().delete_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index 6d1042d3126..bc70aa6ff13 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -208,7 +208,7 @@ async def test_async_cluster_trigger_run_returns_error_event( trigger_event = event assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME - assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING + assert trigger_event.payload["cluster_state"] == ClusterStatus.State.ERROR @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") From f368278682ad5609723fdb5b9a55a256f7171492 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 24 Apr 2024 20:45:19 +0545 Subject: [PATCH 8/9] Fix the test --- .../google/cloud/triggers/dataproc.py | 3 +- .../google/cloud/triggers/test_dataproc.py | 47 ++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 5e082fc9cd7..0b0e0c60d44 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -158,7 +158,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: ) async def run(self) -> AsyncIterator[TriggerEvent]: - """Run the trigger.""" try: while True: cluster = await self.fetch_cluster() @@ -168,7 +167,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent( { "cluster_name": self.cluster_name, - "cluster_state": state.ERROR, + "cluster_state": ClusterStatus.State.DELETING, "cluster": cluster, } ) diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index bc70aa6ff13..e310f2e0dfc 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -208,7 +208,7 @@ async def test_async_cluster_trigger_run_returns_error_event( trigger_event = event assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME - assert trigger_event.payload["cluster_state"] == ClusterStatus.State.ERROR + assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") @@ -231,6 +231,51 @@ async def test_cluster_run_loop_is_still_running( assert f"Current state is: {ClusterStatus.State.CREATING}." assert f"Sleeping for {TEST_POLL_INTERVAL} seconds." + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook") + async def test_cluster_trigger_cancellation_handling( + self, mock_get_sync_hook, mock_get_async_hook, caplog + ): + cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING)) + mock_get_async_hook.return_value.get_cluster.return_value = asyncio.Future() + mock_get_async_hook.return_value.get_cluster.return_value.set_result(cluster) + + mock_delete_cluster = mock.MagicMock() + mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster + + cluster_trigger = DataprocClusterTrigger( + cluster_name="cluster_name", + project_id="project-id", + region="region", + gcp_conn_id="google_cloud_default", + impersonation_chain=None, + polling_interval_seconds=5, + delete_on_error=True, + ) + + cluster_trigger_gen = cluster_trigger.run() + + try: + await cluster_trigger_gen.__anext__() + await cluster_trigger_gen.aclose() + + except asyncio.CancelledError: + # Verify that cancellation was handled as expected + if cluster_trigger.delete_on_error: + mock_get_sync_hook.assert_called_once() + mock_delete_cluster.assert_called_once_with( + region=cluster_trigger.region, + cluster_name=cluster_trigger.cluster_name, + project_id=cluster_trigger.project_id, + ) + assert "Deleting cluster" in caplog.text + assert "Deleted cluster" in caplog.text + else: + mock_delete_cluster.assert_not_called() + except Exception as e: + pytest.fail(f"Unexpected exception raised: {e}") + @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster): From 36aa2f8cf54cbe6545b9d01ab92c74627740ce9c Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Thu, 25 Apr 2024 16:38:08 +0545 Subject: [PATCH 9/9] Fix the PR comments --- airflow/providers/google/cloud/triggers/dataproc.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 0b0e0c60d44..32b536a2eca 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -181,6 +181,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: } ) return + self.log.info("Current state is %s", state) self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) await asyncio.sleep(self.polling_interval_seconds) except asyncio.CancelledError: @@ -205,7 +206,7 @@ async def fetch_cluster(self) -> Cluster: project_id=self.project_id, region=self.region, cluster_name=self.cluster_name ) - async def delete_when_error_occurred(self, cluster: Cluster): + async def delete_when_error_occurred(self, cluster: Cluster) -> None: """ Delete the cluster on error. @@ -218,9 +219,7 @@ async def delete_when_error_occurred(self, cluster: Cluster): ) self.log.info("Cluster %s has been deleted.", self.cluster_name) else: - self.log.info( - "Cluster %s is not be deleted as delete_on_error is set to False.", self.cluster_name - ) + self.log.info("Cluster %s is not deleted as delete_on_error is set to False.", self.cluster_name) class DataprocBatchTrigger(DataprocBaseTrigger):