Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client_fn argument to start_client #2303

Merged
merged 18 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

### What's new?

- **Unify client API** ([#2303](https:/adap/flower/pull/2303))

Using the `client_fn`, Flower clients can interchangeably run as standalone processes (i.e. via `start_client`) or in simulation (i.e. via `start_simulation`) without requiring changes to how the client class is defined and instantiated.

- **General updates to baselines** ([#2301](https:/adap/flower/pull/2301).[#2305](https:/adap/flower/pull/2305), [#2307](https:/adap/flower/pull/2307), [#2327](https:/adap/flower/pull/2327))

- **General improvements** ([#2309](https:/adap/flower/pull/2309), [#2310](https:/adap/flower/pull/2310), [2313](https:/adap/flower/pull/2313), [#2316](https:/adap/flower/pull/2316), [2317](https:/adap/flower/pull/2317))
Expand Down
5 changes: 0 additions & 5 deletions src/py/flwr/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,16 @@
# ==============================================================================
"""Flower client."""


from .app import ClientLike as ClientLike
from .app import run_client as run_client
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
from .app import start_client as start_client
from .app import start_numpy_client as start_numpy_client
from .app import to_client as to_client
from .client import Client as Client
from .numpy_client import NumPyClient as NumPyClient

__all__ = [
"Client",
"ClientLike",
"NumPyClient",
"run_client",
"start_client",
"start_numpy_client",
"to_client",
]
236 changes: 66 additions & 170 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,10 @@
import sys
import time
from logging import INFO
from typing import Callable, Dict, Optional, Union

from flwr.common import (
GRPC_MAX_MESSAGE_LENGTH,
EventType,
event,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from typing import Callable, Optional, Union

from flwr.client.typing import ClientFn, ClientLike
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.constant import (
MISSING_EXTRA_REST,
Expand All @@ -36,64 +31,33 @@
TRANSPORT_TYPES,
)
from flwr.common.logger import log
from flwr.common.typing import (
Code,
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetParametersIns,
GetParametersRes,
GetPropertiesIns,
GetPropertiesRes,
NDArrays,
Status,
)

from .client import Client
from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
from .message_handler.message_handler import handle
from .numpy_client import NumPyClient
from .numpy_client import has_evaluate as numpyclient_has_evaluate
from .numpy_client import has_fit as numpyclient_has_fit
from .numpy_client import has_get_parameters as numpyclient_has_get_parameters
from .numpy_client import has_get_properties as numpyclient_has_get_properties

EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = """
NumPyClient.fit did not return a tuple with 3 elements.
The returned values should have the following type signature:

Tuple[NDArrays, int, Dict[str, Scalar]]

Example
-------
from .numpy_client_wrapper import _wrap_numpy_client

model.get_weights(), 10, {"accuracy": 0.95}

"""

EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE = """
NumPyClient.evaluate did not return a tuple with 3 elements.
The returned values should have the following type signature:

Tuple[float, int, Dict[str, Scalar]]

Example
-------

0.5, 10, {"accuracy": 0.95}

"""
def _check_actionable_client(
client: Optional[ClientLike], client_fn: Optional[ClientFn]
) -> None:
if client_fn is None and client is None:
raise Exception("Both `client_fn` and `client` are `None`, but one is required")

ClientLike = Union[Client, NumPyClient]
if client_fn is not None and client is not None:
raise Exception(
"Both `client_fn` and `client` are provided, but only one is allowed"
)


# pylint: disable=import-outside-toplevel,too-many-locals,too-many-branches
# pylint: disable=too-many-statements
def start_client(
*,
server_address: str,
client: Client,
client_fn: Optional[ClientFn] = None,
client: Optional[ClientLike] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[Union[bytes, str]] = None,
transport: Optional[str] = None,
Expand All @@ -106,9 +70,11 @@ def start_client(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
client : flwr.client.Client
client_fn : Optional[ClientFn]
A callable that instantiates a Client. (default: None)
client : Optional[flwr.client.Client]
An implementation of the abstract base
class `flwr.client.Client`.
class `flwr.client.Client`. (default: None)
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
The maximum length of gRPC messages that can be exchanged with the
Flower server. The default should be sufficient for most models.
Expand All @@ -130,22 +96,43 @@ class `flwr.client.Client`.
--------
Starting a gRPC client with an insecure server connection:

>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> )

Starting an SSL-enabled gRPC client:

>>> from pathlib import Path
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> root_certificates=Path("/crts/root.pem").read_bytes(),
>>> )
"""
event(EventType.START_CLIENT_ENTER)

_check_actionable_client(client, client_fn)

if client_fn is None:
# Wrap `Client` instance in `client_fn`
def single_client_factory(
cid: str, # pylint: disable=unused-argument
) -> ClientLike:
if client is None: # Added this to keep mypy happy
raise Exception(
"Both `client_fn` and `client` are `None`, but one is required"
)
return client # Always return the same instance

client_fn = single_client_factory

# Parse IP address
parsed_address = parse_address(server_address)
if not parsed_address:
Expand Down Expand Up @@ -196,7 +183,7 @@ class `flwr.client.Client`.
if task_ins is None:
time.sleep(3) # Wait for 3s before asking again
continue
task_res, sleep_duration, keep_going = handle(client, task_ins)
task_res, sleep_duration, keep_going = handle(client_fn, task_ins)
send(task_res)
if not keep_going:
break
Expand All @@ -222,7 +209,8 @@ class `flwr.client.Client`.
def start_numpy_client(
*,
server_address: str,
client: NumPyClient,
client_fn: Optional[Callable[[str], NumPyClient]] = None,
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
client: Optional[NumPyClient] = None,
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[bytes] = None,
transport: Optional[str] = None,
Expand All @@ -235,7 +223,9 @@ def start_numpy_client(
The IPv4 or IPv6 address of the server. If the Flower server runs on
the same machine on port 8080, then `server_address` would be
`"[::]:8080"`.
client : flwr.client.NumPyClient
client_fn : Optional[Callable[[str], NumPyClient]]
A callable that instantiates a NumPyClient. (default: None)
client : Optional[flwr.client.NumPyClient]
An implementation of the abstract base class `flwr.client.NumPyClient`.
grpc_max_message_length : int (default: 536_870_912, this equals 512MB)
The maximum length of gRPC messages that can be exchanged with the
Expand All @@ -258,134 +248,40 @@ def start_numpy_client(
--------
Starting a client with an insecure server connection:

>>> start_client(
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> )

Starting a SSL-enabled client:
Starting an SSL-enabled gRPC client:

>>> from pathlib import Path
>>> start_client(
>>> def client_fn(cid: str):
>>> return FlowerClient()
>>>
>>> start_numpy_client(
>>> server_address=localhost:8080,
>>> client=FlowerClient(),
>>> client_fn=client_fn,
>>> root_certificates=Path("/crts/root.pem").read_bytes(),
>>> )
"""
# Start
_check_actionable_client(client, client_fn)

wrp_client = _wrap_numpy_client(client=client) if client else None
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
start_client(
server_address=server_address,
client=_wrap_numpy_client(client=client),
client_fn=client_fn,
client=wrp_client,
grpc_max_message_length=grpc_max_message_length,
root_certificates=root_certificates,
transport=transport,
)


def to_client(client_like: ClientLike) -> Client:
"""Take any Client-like object and return it as a Client."""
if isinstance(client_like, NumPyClient):
return _wrap_numpy_client(client=client_like)
return client_like


def _constructor(self: Client, numpy_client: NumPyClient) -> None:
self.numpy_client = numpy_client # type: ignore


def _get_properties(self: Client, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return the current client properties."""
properties = self.numpy_client.get_properties(config=ins.config) # type: ignore
return GetPropertiesRes(
status=Status(code=Code.OK, message="Success"),
properties=properties,
)


def _get_parameters(self: Client, ins: GetParametersIns) -> GetParametersRes:
"""Return the current local model parameters."""
parameters = self.numpy_client.get_parameters(config=ins.config) # type: ignore
parameters_proto = ndarrays_to_parameters(parameters)
return GetParametersRes(
status=Status(code=Code.OK, message="Success"), parameters=parameters_proto
)


def _fit(self: Client, ins: FitIns) -> FitRes:
"""Refine the provided parameters using the locally held dataset."""
# Deconstruct FitIns
parameters: NDArrays = parameters_to_ndarrays(ins.parameters)

# Train
results = self.numpy_client.fit(parameters, ins.config) # type: ignore
if not (
len(results) == 3
and isinstance(results[0], list)
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)

# Return FitRes
parameters_prime, num_examples, metrics = results
parameters_prime_proto = ndarrays_to_parameters(parameters_prime)
return FitRes(
status=Status(code=Code.OK, message="Success"),
parameters=parameters_prime_proto,
num_examples=num_examples,
metrics=metrics,
)


def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
"""Evaluate the provided parameters using the locally held dataset."""
parameters: NDArrays = parameters_to_ndarrays(ins.parameters)

results = self.numpy_client.evaluate(parameters, ins.config) # type: ignore
if not (
len(results) == 3
and isinstance(results[0], float)
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)

# Return EvaluateRes
loss, num_examples, metrics = results
return EvaluateRes(
status=Status(code=Code.OK, message="Success"),
loss=loss,
num_examples=num_examples,
metrics=metrics,
)


def _wrap_numpy_client(client: NumPyClient) -> Client:
member_dict: Dict[str, Callable] = { # type: ignore
"__init__": _constructor,
}

# Add wrapper type methods (if overridden)

if numpyclient_has_get_properties(client=client):
member_dict["get_properties"] = _get_properties

if numpyclient_has_get_parameters(client=client):
member_dict["get_parameters"] = _get_parameters

if numpyclient_has_fit(client=client):
member_dict["fit"] = _fit

if numpyclient_has_evaluate(client=client):
member_dict["evaluate"] = _evaluate

# Create wrapper class
wrapper_class = type("NumPyClientWrapper", (Client,), member_dict)

# Create and return an instance of the newly created class
return wrapper_class(numpy_client=client) # type: ignore


def run_client() -> None:
"""Run Flower client."""
log(INFO, "Running Flower client...")
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/client/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from typing import Dict, Tuple

from flwr.client.numpy_client_wrapper import to_client
from flwr.client.typing import ClientLike
from flwr.common import (
Config,
EvaluateIns,
Expand All @@ -31,7 +33,7 @@
Scalar,
)

from .app import ClientLike, start_client, start_numpy_client, to_client
from .app import start_client, start_numpy_client
from .client import Client
from .numpy_client import NumPyClient

Expand Down
Loading
Loading