diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index f5d3a6d2b6f..79416a8eb31 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -25,7 +25,7 @@ import grpc from flwr.client.grpc_rere_client.connection import grpc_request_response -from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, serde from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.record import RecordSet @@ -46,7 +46,9 @@ PushTaskResRequest, PushTaskResResponse, ) +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request @@ -75,15 +77,31 @@ def unary_unary( if isinstance(request, CreateNodeRequest): context.send_initial_metadata( - ((_PUBLIC_KEY_HEADER, self.server_public_key),) + ( + ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self.server_public_key) + ), + ), + ) ) - return CreateNodeResponse() + return CreateNodeResponse(node=Node(node_id=123)) if isinstance(request, DeleteNodeRequest): return DeleteNodeResponse() if isinstance(request, PushTaskResRequest): return PushTaskResResponse() - return PullTaskInsResponse() + return PullTaskInsResponse( + task_ins_list=[ + TaskIns( + task=Task( + consumer=Node(node_id=123), + recordset=serde.recordset_to_proto(RecordSet()), + ) + ) + ] + ) def received_client_metadata( self, @@ -132,6 +150,16 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: server.add_generic_rpc_handlers((generic_handler,)) +def _get_value_from_tuples( + key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] +) -> bytes: + value = next((value for key, value in tuples if key == key_string), "") + if isinstance(value, str): + return value.encode() + + return value + + def _init_retry_invoker() -> RetryInvoker: return RetryInvoker( wait_gen_factory=exponential, @@ -205,13 +233,20 @@ def test_client_auth_create_node(self) -> None: _, _, create_node, _, _, _ = conn assert create_node is not None create_node() - expected_client_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode(public_key_to_bytes(self._client_public_key)), + + received_metadata = self._servicer.received_client_metadata() + assert received_metadata is not None + + actual_public_key = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, received_metadata + ) + + expected_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) ) # Assert - assert self._servicer.received_client_metadata() == expected_client_metadata + assert actual_public_key == expected_public_key def test_client_auth_delete_node(self) -> None: """Test client authentication during delete node.""" @@ -227,30 +262,32 @@ def test_client_auth_delete_node(self) -> None: None, (self._client_private_key, self._client_public_key), ) as conn: - _, _, _, delete_node, _, _ = conn + _, _, create_node, delete_node, _, _ = conn + assert create_node is not None + create_node() assert delete_node is not None delete_node() + + received_metadata = self._servicer.received_client_metadata() + assert received_metadata is not None + shared_secret = generate_shared_key( self._servicer.server_private_key, self._client_public_key ) - expected_hmac = compute_hmac( - shared_secret, self._servicer.received_message_bytes() + expected_hmac = base64.urlsafe_b64encode( + compute_hmac(shared_secret, self._servicer.received_message_bytes()) ) - expected_client_metadata = ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ), - ), - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode(expected_hmac), - ), + actual_public_key = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, received_metadata + ) + actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) + expected_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) ) # Assert - assert self._servicer.received_client_metadata() == expected_client_metadata + assert actual_public_key == expected_public_key + assert actual_hmac == expected_hmac def test_client_auth_receive(self) -> None: """Test client authentication during receive node.""" @@ -266,36 +303,38 @@ def test_client_auth_receive(self) -> None: None, (self._client_private_key, self._client_public_key), ) as conn: - receive, _, _, _, _, _ = conn + receive, _, create_node, _, _, _ = conn + assert create_node is not None + create_node() assert receive is not None receive() + + received_metadata = self._servicer.received_client_metadata() + assert received_metadata is not None + shared_secret = generate_shared_key( self._servicer.server_private_key, self._client_public_key ) - expected_hmac = compute_hmac( - shared_secret, self._servicer.received_message_bytes() + expected_hmac = base64.urlsafe_b64encode( + compute_hmac(shared_secret, self._servicer.received_message_bytes()) ) - expected_client_metadata = ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ), - ), - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode(expected_hmac), - ), + actual_public_key = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, received_metadata + ) + actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) + expected_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) ) # Assert - assert self._servicer.received_client_metadata() == expected_client_metadata + assert actual_public_key == expected_public_key + assert actual_hmac == expected_hmac def test_client_auth_send(self) -> None: """Test client authentication during send node.""" # Prepare retry_invoker = _init_retry_invoker() - message = Message(Metadata(0, "1", 0, 0, "", "", 0, ""), RecordSet()) + message = Message(Metadata(0, "", 123, 0, "", "", 0, ""), RecordSet()) # Execute with self._connection( @@ -306,30 +345,34 @@ def test_client_auth_send(self) -> None: None, (self._client_private_key, self._client_public_key), ) as conn: - _, send, _, _, _, _ = conn + receive, send, create_node, _, _, _ = conn + assert create_node is not None + create_node() + assert receive is not None + receive() assert send is not None send(message) + + received_metadata = self._servicer.received_client_metadata() + assert received_metadata is not None + shared_secret = generate_shared_key( self._servicer.server_private_key, self._client_public_key ) - expected_hmac = compute_hmac( - shared_secret, self._servicer.received_message_bytes() + expected_hmac = base64.urlsafe_b64encode( + compute_hmac(shared_secret, self._servicer.received_message_bytes()) + ) + actual_public_key = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, received_metadata ) - expected_client_metadata = ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ), - ), - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode(expected_hmac), - ), + actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) + expected_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) ) # Assert - assert self._servicer.received_client_metadata() == expected_client_metadata + assert actual_public_key == expected_public_key + assert actual_hmac == expected_hmac def test_client_auth_get_run(self) -> None: """Test client authentication during send node.""" @@ -345,30 +388,32 @@ def test_client_auth_get_run(self) -> None: None, (self._client_private_key, self._client_public_key), ) as conn: - _, _, _, _, get_run, _ = conn + _, _, create_node, _, get_run, _ = conn + assert create_node is not None + create_node() assert get_run is not None get_run(0) + + received_metadata = self._servicer.received_client_metadata() + assert received_metadata is not None + shared_secret = generate_shared_key( self._servicer.server_private_key, self._client_public_key ) - expected_hmac = compute_hmac( - shared_secret, self._servicer.received_message_bytes() + expected_hmac = base64.urlsafe_b64encode( + compute_hmac(shared_secret, self._servicer.received_message_bytes()) + ) + actual_public_key = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, received_metadata ) - expected_client_metadata = ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ), - ), - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode(expected_hmac), - ), + actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) + expected_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) ) # Assert - assert self._servicer.received_client_metadata() == expected_client_metadata + assert actual_public_key == expected_public_key + assert actual_hmac == expected_hmac if __name__ == "__main__":