Skip to content

Commit

Permalink
added MultipleFilesWebHdfsSensor (#43045)
Browse files Browse the repository at this point in the history
  • Loading branch information
eilon246810 authored Oct 17, 2024
1 parent b010acd commit 3b9e156
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
32 changes: 31 additions & 1 deletion providers/src/airflow/providers/apache/hdfs/sensors/web_hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Sequence, Union

from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.utils.context import Context
from hdfs.ext.kerberos import KerberosClient
from hdfs import InsecureClient


class WebHdfsSensor(BaseSensorOperator):
Expand All @@ -41,3 +43,31 @@ def poke(self, context: Context) -> bool:
hook = WebHDFSHook(self.webhdfs_conn_id)
self.log.info("Poking for file %s", self.filepath)
return hook.check_for_path(hdfs_path=self.filepath)


class MultipleFilesWebHdfsSensor(BaseSensorOperator):
    """
    Wait until every expected file name shows up in one HDFS directory.

    Each poke lists ``directory_path`` through the client returned by
    ``WebHDFSHook.get_conn()`` and succeeds only when no entry of
    ``expected_filenames`` is absent from that listing.

    :param directory_path: HDFS directory to list (templated).
    :param expected_filenames: Names that must all be present in the listing
        (templated). A bare string is treated as a single file name.
    :param webhdfs_conn_id: Connection id passed to ``WebHDFSHook``.
    """

    template_fields: Sequence[str] = ("directory_path", "expected_filenames")

    def __init__(
        self,
        *,
        directory_path: str,
        expected_filenames: Sequence[str],
        webhdfs_conn_id: str = "webhdfs_default",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.directory_path = directory_path
        self.expected_filenames = expected_filenames
        self.webhdfs_conn_id = webhdfs_conn_id

    def poke(self, context: Context) -> bool:
        """
        Return ``True`` when all expected file names are present in the directory listing.

        :param context: Task execution context (unused by this check).
        :return: ``False`` while at least one expected name is missing, ``True`` otherwise.
        """
        # Imported lazily, mirroring the other sensor in this module.
        from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook

        hook = WebHDFSHook(self.webhdfs_conn_id)
        conn: KerberosClient | InsecureClient = hook.get_conn()

        actual_files = set(conn.list(self.directory_path))
        self.log.debug("Files Found in directory: %s", actual_files)

        # ``expected_filenames`` is a template field, so by poke time it may be a
        # plain string (e.g. a single rendered Jinja expression). Feeding a str to
        # set() would split it into characters; treat it as one file name instead.
        expected = self.expected_filenames
        if isinstance(expected, str):
            expected = [expected]

        missing_files = set(expected) - actual_files
        if missing_files:
            self.log.info("There are missing files: %s", missing_files)
            return False
        return True
48 changes: 46 additions & 2 deletions providers/tests/apache/hdfs/sensors/test_web_hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
# under the License.
from __future__ import annotations

import os
from unittest import mock

from airflow.providers.apache.hdfs.sensors.web_hdfs import WebHdfsSensor
from airflow.providers.apache.hdfs.sensors.web_hdfs import WebHdfsSensor, MultipleFilesWebHdfsSensor

# Shared fixtures for the WebHDFS sensor tests.
TEST_HDFS_CONN = "webhdfs_default"
TEST_HDFS_DIRECTORY = "hdfs://user/hive/warehouse/airflow.db"
TEST_HDFS_FILENAMES = ["static_babynames1", "static_babynames2", "static_babynames3"]
# HDFS paths always use "/" as the separator; os.path.join would insert the host
# OS separator (a backslash on Windows), so the path is joined explicitly.
TEST_HDFS_PATH = f"{TEST_HDFS_DIRECTORY}/{TEST_HDFS_FILENAMES[0]}"


class TestWebHdfsSensor:
Expand Down Expand Up @@ -55,3 +58,44 @@ def test_poke_should_return_false_for_non_existing_table(self, mock_hook):

mock_hook.return_value.check_for_path.assert_called_once_with(hdfs_path=TEST_HDFS_PATH)
mock_hook.assert_called_once_with(TEST_HDFS_CONN)


class TestMultipleFilesWebHdfsSensor:
    """Tests for ``MultipleFilesWebHdfsSensor.poke`` against a mocked ``WebHDFSHook``."""

    @mock.patch("airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook")
    def test_poke(self, mock_hook, caplog):
        """The poke succeeds when the directory listing contains every expected file."""
        mock_hook.return_value.get_conn.return_value.list.return_value = TEST_HDFS_FILENAMES

        sensor = MultipleFilesWebHdfsSensor(
            task_id="test_task",
            webhdfs_conn_id=TEST_HDFS_CONN,
            directory_path=TEST_HDFS_DIRECTORY,
            expected_filenames=TEST_HDFS_FILENAMES,
        )
        result = sensor.poke({})

        assert result
        assert "Files Found in directory: " in caplog.text

        mock_hook.return_value.get_conn.return_value.list.assert_called_once_with(TEST_HDFS_DIRECTORY)
        mock_hook.return_value.get_conn.assert_called_once()
        mock_hook.assert_called_once_with(TEST_HDFS_CONN)

    @mock.patch("airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook")
    def test_poke_should_return_false_for_missing_file(self, mock_hook, caplog):
        """The poke fails when only some of the expected files are listed."""
        # The mocked listing must be a list of names. A bare string would be turned
        # into a set of characters by the sensor, so the test would not model
        # "one file present, two missing" at all.
        mock_hook.return_value.get_conn.return_value.list.return_value = [TEST_HDFS_FILENAMES[0]]

        sensor = MultipleFilesWebHdfsSensor(
            task_id="test_task",
            webhdfs_conn_id=TEST_HDFS_CONN,
            directory_path=TEST_HDFS_DIRECTORY,
            expected_filenames=TEST_HDFS_FILENAMES,
        )
        exists = sensor.poke({})

        assert not exists
        assert "Files Found in directory: " in caplog.text
        assert "There are missing files: " in caplog.text

        mock_hook.return_value.get_conn.return_value.list.assert_called_once_with(TEST_HDFS_DIRECTORY)
        mock_hook.return_value.get_conn.assert_called_once()
        mock_hook.assert_called_once_with(TEST_HDFS_CONN)

0 comments on commit 3b9e156

Please sign in to comment.