Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(*:skip) Fix client interceptor tests #4047

Merged
merged 11 commits into from
Aug 20, 2024
183 changes: 114 additions & 69 deletions src/py/flwr/client/grpc_rere_client/client_interceptor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import grpc

from flwr.client.grpc_rere_client.connection import grpc_request_response
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, serde
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.record import RecordSet
Expand All @@ -46,7 +46,9 @@
PushTaskResRequest,
PushTaskResResponse,
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611

from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request

Expand Down Expand Up @@ -75,15 +77,31 @@ def unary_unary(

if isinstance(request, CreateNodeRequest):
context.send_initial_metadata(
((_PUBLIC_KEY_HEADER, self.server_public_key),)
(
(
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(
public_key_to_bytes(self.server_public_key)
),
),
)
)
return CreateNodeResponse()
return CreateNodeResponse(node=Node(node_id=123))
if isinstance(request, DeleteNodeRequest):
return DeleteNodeResponse()
if isinstance(request, PushTaskResRequest):
return PushTaskResResponse()

return PullTaskInsResponse()
return PullTaskInsResponse(
task_ins_list=[
TaskIns(
task=Task(
consumer=Node(node_id=123),
recordset=serde.recordset_to_proto(RecordSet()),
)
)
]
)

def received_client_metadata(
self,
Expand Down Expand Up @@ -132,6 +150,16 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None:
server.add_generic_rpc_handlers((generic_handler,))


def _get_value_from_tuples(
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
) -> bytes:
value = next((value for key, value in tuples if key == key_string), "")
if isinstance(value, str):
return value.encode()

return value


def _init_retry_invoker() -> RetryInvoker:
return RetryInvoker(
wait_gen_factory=exponential,
Expand Down Expand Up @@ -205,13 +233,20 @@ def test_client_auth_create_node(self) -> None:
_, _, create_node, _, _, _ = conn
assert create_node is not None
create_node()
expected_client_metadata = (
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(public_key_to_bytes(self._client_public_key)),

received_metadata = self._servicer.received_client_metadata()
assert received_metadata is not None

actual_public_key = _get_value_from_tuples(
_PUBLIC_KEY_HEADER, received_metadata
)

expected_public_key = base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
)

# Assert
assert self._servicer.received_client_metadata() == expected_client_metadata
assert actual_public_key == expected_public_key

def test_client_auth_delete_node(self) -> None:
"""Test client authentication during delete node."""
Expand All @@ -227,30 +262,32 @@ def test_client_auth_delete_node(self) -> None:
None,
(self._client_private_key, self._client_public_key),
) as conn:
_, _, _, delete_node, _, _ = conn
_, _, create_node, delete_node, _, _ = conn
assert create_node is not None
create_node()
assert delete_node is not None
delete_node()

received_metadata = self._servicer.received_client_metadata()
assert received_metadata is not None

shared_secret = generate_shared_key(
self._servicer.server_private_key, self._client_public_key
)
expected_hmac = compute_hmac(
shared_secret, self._servicer.received_message_bytes()
expected_hmac = base64.urlsafe_b64encode(
compute_hmac(shared_secret, self._servicer.received_message_bytes())
)
expected_client_metadata = (
(
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
),
),
(
_AUTH_TOKEN_HEADER,
base64.urlsafe_b64encode(expected_hmac),
),
actual_public_key = _get_value_from_tuples(
_PUBLIC_KEY_HEADER, received_metadata
)
actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata)
expected_public_key = base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
)

# Assert
assert self._servicer.received_client_metadata() == expected_client_metadata
assert actual_public_key == expected_public_key
assert actual_hmac == expected_hmac

def test_client_auth_receive(self) -> None:
"""Test client authentication during receive node."""
Expand All @@ -266,36 +303,38 @@ def test_client_auth_receive(self) -> None:
None,
(self._client_private_key, self._client_public_key),
) as conn:
receive, _, _, _, _, _ = conn
receive, _, create_node, _, _, _ = conn
assert create_node is not None
create_node()
assert receive is not None
receive()

received_metadata = self._servicer.received_client_metadata()
assert received_metadata is not None

shared_secret = generate_shared_key(
self._servicer.server_private_key, self._client_public_key
)
expected_hmac = compute_hmac(
shared_secret, self._servicer.received_message_bytes()
expected_hmac = base64.urlsafe_b64encode(
compute_hmac(shared_secret, self._servicer.received_message_bytes())
)
expected_client_metadata = (
(
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
),
),
(
_AUTH_TOKEN_HEADER,
base64.urlsafe_b64encode(expected_hmac),
),
actual_public_key = _get_value_from_tuples(
_PUBLIC_KEY_HEADER, received_metadata
)
actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata)
expected_public_key = base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
)

# Assert
assert self._servicer.received_client_metadata() == expected_client_metadata
assert actual_public_key == expected_public_key
assert actual_hmac == expected_hmac

def test_client_auth_send(self) -> None:
"""Test client authentication during send node."""
# Prepare
retry_invoker = _init_retry_invoker()
message = Message(Metadata(0, "1", 0, 0, "", "", 0, ""), RecordSet())
message = Message(Metadata(0, "", 123, 0, "", "", 0, ""), RecordSet())

# Execute
with self._connection(
Expand All @@ -306,30 +345,34 @@ def test_client_auth_send(self) -> None:
None,
(self._client_private_key, self._client_public_key),
) as conn:
_, send, _, _, _, _ = conn
receive, send, create_node, _, _, _ = conn
assert create_node is not None
create_node()
assert receive is not None
receive()
assert send is not None
send(message)

received_metadata = self._servicer.received_client_metadata()
assert received_metadata is not None

shared_secret = generate_shared_key(
self._servicer.server_private_key, self._client_public_key
)
expected_hmac = compute_hmac(
shared_secret, self._servicer.received_message_bytes()
expected_hmac = base64.urlsafe_b64encode(
compute_hmac(shared_secret, self._servicer.received_message_bytes())
)
actual_public_key = _get_value_from_tuples(
_PUBLIC_KEY_HEADER, received_metadata
)
expected_client_metadata = (
(
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
),
),
(
_AUTH_TOKEN_HEADER,
base64.urlsafe_b64encode(expected_hmac),
),
actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata)
expected_public_key = base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
)

# Assert
assert self._servicer.received_client_metadata() == expected_client_metadata
assert actual_public_key == expected_public_key
assert actual_hmac == expected_hmac

def test_client_auth_get_run(self) -> None:
"""Test client authentication during send node."""
Expand All @@ -345,30 +388,32 @@ def test_client_auth_get_run(self) -> None:
None,
(self._client_private_key, self._client_public_key),
) as conn:
_, _, _, _, get_run, _ = conn
_, _, create_node, _, get_run, _ = conn
assert create_node is not None
create_node()
assert get_run is not None
get_run(0)

received_metadata = self._servicer.received_client_metadata()
assert received_metadata is not None

shared_secret = generate_shared_key(
self._servicer.server_private_key, self._client_public_key
)
expected_hmac = compute_hmac(
shared_secret, self._servicer.received_message_bytes()
expected_hmac = base64.urlsafe_b64encode(
compute_hmac(shared_secret, self._servicer.received_message_bytes())
)
actual_public_key = _get_value_from_tuples(
_PUBLIC_KEY_HEADER, received_metadata
)
expected_client_metadata = (
(
_PUBLIC_KEY_HEADER,
base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
),
),
(
_AUTH_TOKEN_HEADER,
base64.urlsafe_b64encode(expected_hmac),
),
actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata)
expected_public_key = base64.urlsafe_b64encode(
public_key_to_bytes(self._client_public_key)
)

# Assert
assert self._servicer.received_client_metadata() == expected_client_metadata
assert actual_public_key == expected_public_key
assert actual_hmac == expected_hmac


if __name__ == "__main__":
Expand Down