diff --git a/astronomer/providers/amazon/aws/hooks/redshift_data.py b/astronomer/providers/amazon/aws/hooks/redshift_data.py index 61f17441d..a9e658c39 100644 --- a/astronomer/providers/amazon/aws/hooks/redshift_data.py +++ b/astronomer/providers/amazon/aws/hooks/redshift_data.py @@ -77,6 +77,9 @@ def get_conn_params(self) -> Dict[str, Union[str, int]]: if "secret_access_key" in extra_config else extra_config["aws_secret_access_key"] ) + elif connection_object.login: + conn_params["aws_access_key_id"] = connection_object.login + conn_params["aws_secret_access_key"] = connection_object.password else: raise AirflowException("Required access_key_id, aws_secret_access_key") @@ -88,6 +91,12 @@ def get_conn_params(self) -> Dict[str, Union[str, int]]: else: raise AirflowException("Required Region name is missing !") + if "aws_session_token" in extra_config: + self.log.info( + "session token retrieved from extra, please note you are responsible for renewing these.", + ) + conn_params["aws_session_token"] = extra_config["aws_session_token"] + if "cluster_identifier" in extra_config: self.log.info("Retrieving cluster_identifier from Connection.extra_config['cluster_identifier']") conn_params["cluster_identifier"] = extra_config["cluster_identifier"] diff --git a/tests/amazon/aws/hooks/test_redshift_data.py b/tests/amazon/aws/hooks/test_redshift_data.py index fd06ac544..02b57b420 100644 --- a/tests/amazon/aws/hooks/test_redshift_data.py +++ b/tests/amazon/aws/hooks/test_redshift_data.py @@ -298,6 +298,47 @@ def test_get_conn_params(mock_get_connection, connection_details, expected_outpu assert response == expected_output +@pytest.mark.parametrize( + "mock_login, mock_pwd, connection_details, expected_output", + [ + ( + "test", + "test", + { + "db_user": "test_user", + "cluster_identifier": "test_cluster", + "region": "us-east-2", + "database": "test-redshift_database", + "aws_session_token": "test", + }, + { + "aws_access_key_id": "test", + "aws_secret_access_key": "test", + "aws_session_token": "test", + "db_user": "test_user", + "cluster_identifier": "test_cluster", + "region_name": "us-east-2", + "database": "test-redshift_database", + }, + ), + ], +) +@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.get_connection") +def test_get_conn_params_with_login_pwd( + mock_get_connection, mock_login, mock_pwd, connection_details, expected_output +): + """ + Test get_conn_params by mocking the AWS secret and access key and session token, + passing access and secret key in connection login and password instead passing in extra + """ + mock_conn = Connection(login=mock_login, password=mock_pwd, extra=json.dumps(connection_details)) + mock_get_connection.return_value = mock_conn + + hook = RedshiftDataHook(client_type="redshift-data") + response = hook.get_conn_params() + assert response == expected_output + + @pytest.mark.parametrize( "connection_details, test", [