Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client_fn argument to start_client #2303

Merged
merged 18 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

### What's new?

- **Unify client API** ([#2303](https:/adap/flower/pull/2303))

Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated.

- **General updates to baselines** ([#2301](https:/adap/flower/pull/2301).[#2305](https:/adap/flower/pull/2305), [#2307](https:/adap/flower/pull/2307), [#2327](https:/adap/flower/pull/2327))

- **General improvements** ([#2309](https:/adap/flower/pull/2309), [#2310](https:/adap/flower/pull/2310), [2313](https:/adap/flower/pull/2313), [#2316](https:/adap/flower/pull/2316), [2317](https:/adap/flower/pull/2317),[#2349](https:/adap/flower/pull/2349))
Expand Down
82 changes: 66 additions & 16 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import sys
import time
from logging import INFO
from typing import Optional, Union
from typing import Callable, Optional, Union

from flwr.client.typing import ClientFn, ClientLike
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.constant import (
Expand All @@ -31,19 +32,32 @@
)
from flwr.common.logger import log

from .client import Client
from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
from .message_handler.message_handler import handle
from .numpy_client import NumPyClient
from .numpy_client_wrapper import _wrap_numpy_client


def _check_actionable_client(
client: Optional[ClientLike], client_fn: Optional[ClientFn]
) -> None:
if client_fn is None and client is None:
raise Exception("Both `client_fn` and `client` are `None`, but one is required")

if client_fn is not None and client is not None:
raise Exception(
"Both `client_fn` and `client` are provided, but only one is allowed"
)


# pylint: disable=import-outside-toplevel,too-many-locals,too-many-branches
# pylint: disable=too-many-statements
def start_client(
*,
server_address: str,
client: Client,
client_fn: Optional[ClientFn] = None,
client: Optional[ClientLike] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[Union[bytes, str]] = None,
transport: Optional[str] = None,
Expand All @@ -56,9 +70,11 @@ def start_client(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
client : flwr.client.Client
client_fn : Optional[ClientFn]
A callable that instantiates a Client. (default: None)
client : Optional[flwr.client.Client]
An implementation of the abstract base
class `flwr.client.Client`.
class `flwr.client.Client` (default: None)
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
The maximum length of gRPC messages that can be exchanged with the
Flower server. The default should be sufficient for most models.
Expand All @@ -80,22 +96,43 @@ class `flwr.client.Client`.
--------
Starting a gRPC client with an insecure server connection:

>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> )

Starting an SSL-enabled gRPC client:

>>> from pathlib import Path
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> root_certificates=Path("/crts/root.pem").read_bytes(),
>>> )
"""
event(EventType.START_CLIENT_ENTER)

_check_actionable_client(client, client_fn)

if client_fn is None:
# Wrap `Client` instance in `client_fn`
def single_client_factory(
cid: str, # pylint: disable=unused-argument
) -> ClientLike:
if client is None: # Added this to keep mypy happy
raise Exception(
"Both `client_fn` and `client` are `None`, but one is required"
)
return client # Always return the same instance

client_fn = single_client_factory

# Parse IP address
parsed_address = parse_address(server_address)
if not parsed_address:
Expand Down Expand Up @@ -146,7 +183,7 @@ class `flwr.client.Client`.
if task_ins is None:
time.sleep(3) # Wait for 3s before asking again
continue
task_res, sleep_duration, keep_going = handle(client, task_ins)
task_res, sleep_duration, keep_going = handle(client_fn, task_ins)
send(task_res)
if not keep_going:
break
Expand All @@ -172,7 +209,8 @@ class `flwr.client.Client`.
def start_numpy_client(
*,
server_address: str,
client: NumPyClient,
client_fn: Optional[Callable[[str], NumPyClient]] = None,
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
client: Optional[NumPyClient] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[bytes] = None,
transport: Optional[str] = None,
Expand All @@ -185,7 +223,9 @@ def start_numpy_client(
The IPv4 or IPv6 address of the server. If the Flower server runs on
the same machine on port 8080, then `server_address` would be
`"[::]:8080"`.
client : flwr.client.NumPyClient
client_fn : Optional[Callable[[str], NumPyClient]]
A callable that instantiates a NumPyClient. (default: None)
client : Optional[flwr.client.NumPyClient]
An implementation of the abstract base class `flwr.client.NumPyClient`.
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
The maximum length of gRPC messages that can be exchanged with the
Expand All @@ -208,24 +248,34 @@ def start_numpy_client(
--------
Starting a client with an insecure server connection:

>>> start_client(
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> )

Starting a SSL-enabled client:
Starting an SSL-enabled gRPC client:

>>> from pathlib import Path
>>> start_client(
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> root_certificates=Path("/crts/root.pem").read_bytes(),
>>> )
"""
# Start
_check_actionable_client(client, client_fn)

wrp_client = _wrap_numpy_client(client=client) if client else None
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
start_client(
server_address=server_address,
client=_wrap_numpy_client(client=client),
client_fn=client_fn,
client=wrp_client,
grpc_max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
transport=transport,
Expand Down
26 changes: 19 additions & 7 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
get_server_message_from_task_ins,
wrap_client_message_in_task_res,
)
from flwr.client.numpy_client_wrapper import to_client
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn, ClientLike
from flwr.common import serde
from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
Expand All @@ -38,13 +40,13 @@ class UnknownServerMessage(Exception):
"""Exception indicating that the received message is unknown."""


def handle(client: Client, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
"""Handle incoming TaskIns from the server.

Parameters
----------
client : Client
The Client instance provided by the user.
client_fn : ClientFn
A callable that instantiates a Client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.

Expand All @@ -61,6 +63,9 @@ def handle(client: Client, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
"""
server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)
if server_msg is None:
# Instantiate the client
client_like: ClientLike = client_fn("-1")
client = to_client(client_like)
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -79,20 +84,22 @@ def handle(client: Client, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
)
return task_res, 0, True
raise NotImplementedError()
client_msg, sleep_duration, keep_going = handle_legacy_message(client, server_msg)
client_msg, sleep_duration, keep_going = handle_legacy_message(
client_fn, server_msg
)
task_res = wrap_client_message_in_task_res(client_msg)
return task_res, sleep_duration, keep_going


def handle_legacy_message(
client: Client, server_msg: ServerMessage
client_fn: ClientFn, server_msg: ServerMessage
) -> Tuple[ClientMessage, int, bool]:
"""Handle incoming messages from the server.

Parameters
----------
client : Client
The Client instance provided by the user.
client_fn : ClientFn
A callable that instantiates a Client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.

Expand All @@ -111,6 +118,11 @@ def handle_legacy_message(
if field == "reconnect_ins":
disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins)
return disconnect_msg, sleep_duration, False

# Instantiate the client
client_like: ClientLike = client_fn("-1")
client = to_client(client_like)
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
# Execute task
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins), 0, True
if field == "get_parameters_ins":
Expand Down
12 changes: 10 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import uuid

from flwr.client import Client
from flwr.client.typing import ClientFn
from flwr.common import (
EvaluateIns,
EvaluateRes,
Expand Down Expand Up @@ -103,6 +104,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
)


def _get_client_fn(client: Client) -> ClientFn:
def client_fn(cid: str) -> Client: # pylint: disable=unused-argument
return client

return client_fn


def test_client_without_get_properties() -> None:
"""Test client implementing get_properties."""
# Prepare
Expand All @@ -123,7 +131,7 @@ def test_client_without_get_properties() -> None:

# Execute
task_res, actual_sleep_duration, actual_keep_going = handle(
client=client, task_ins=task_ins
client_fn=_get_client_fn(client), task_ins=task_ins
)

if not task_res.HasField("task"):
Expand Down Expand Up @@ -186,7 +194,7 @@ def test_client_with_get_properties() -> None:

# Execute
task_res, actual_sleep_duration, actual_keep_going = handle(
client=client, task_ins=task_ins
client_fn=_get_client_fn(client), task_ins=task_ins
)

if not task_res.HasField("task"):
Expand Down