Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Introduce new client_fn signature passing the Context #3779

Merged
merged 14 commits into from
Jul 13, 2024
7 changes: 3 additions & 4 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.typing import ClientFnExt
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
from flwr.common.address import parse_address
from flwr.common.constant import (
MISSING_EXTRA_REST,
Expand Down Expand Up @@ -138,7 +138,7 @@ class `flwr.client.Client` (default: None)

Starting an SSL-enabled gRPC client using system certificates:

>>> def client_fn(node_id: int, partition_id: Optional[int]):
>>> def client_fn(context: Context):
>>> return FlowerClient()
>>>
>>> start_client(
Expand Down Expand Up @@ -253,8 +253,7 @@ class `flwr.client.Client` (default: None)
if client_fn is None:
# Wrap `Client` instance in `client_fn`
def single_client_factory(
node_id: int, # pylint: disable=unused-argument
partition_id: Optional[int], # pylint: disable=unused-argument
context: Context, # pylint: disable=unused-argument
) -> Client:
if client is None: # Added this to keep mypy happy
raise ValueError(
Expand Down
16 changes: 9 additions & 7 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,21 @@

def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
client_fn_args = inspect.signature(client_fn).parameters
first_arg = list(client_fn_args.keys())[0]

if not all(key in client_fn_args for key in ["node_id", "partition_id"]):
if len(client_fn_args) != 1 or client_fn_args[first_arg].annotation is not Context:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first arg might not have a Context annotation - we want users to allow def client_fn(context):, right?

Maybe we should just check the name: if the name is cid, or if it has a type annotation str, we use compat mode. Otherwise we assume Context.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What should in the case where we have a type annotation, but it's neither str or Context?

warn_deprecated_feature(
"`client_fn` now expects a signature `def client_fn(node_id: int, "
"partition_id: Optional[int])`.\nYou provided `client_fn` with signature: "
"`client_fn` now expects a signature `def client_fn(context: Context)`."
"\nYou provided `client_fn` with signature: "
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
f"{dict(client_fn_args.items())}"
)

# Wrap depcreated client_fn inside a function with the expected signature
def adaptor_fn(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
return client_fn(str(partition_id)) # type: ignore
def adaptor_fn(context: Context) -> Client: # pylint: disable=unused-argument
# if patition-id is defined, pass it. Else pass node_id that should always
# be defined during Context init.
cid = context.node_config.get("partition-id", context.node_id)
return client_fn(str(cid)) # type: ignore

return adaptor_fn

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype(
client_fn: ClientFnExt, message: Message, context: Context
) -> Message:
"""Handle legacy message in the inner most mod."""
client = client_fn(message.metadata.dst_node_id, context.partition_id)
client = client_fn(context)

# Check if NumPyClient is returend
if isinstance(client, NumPyClient):
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest
import uuid
from copy import copy
from typing import List, Optional
from typing import List

from flwr.client import Client
from flwr.client.typing import ClientFnExt
Expand Down Expand Up @@ -114,9 +114,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:


def _get_client_fn(client: Client) -> ClientFnExt:
def client_fn(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
def client_fn(contex: Context) -> Client: # pylint: disable=unused-argument
return client

return client_fn
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Custom types for Flower clients."""


from typing import Callable, Optional
from typing import Callable

from flwr.common import Context, Message

from .client import Client as Client

# Compatibility
ClientFn = Callable[[str], Client]
ClientFnExt = Callable[[int, Optional[int]], Client]
ClientFnExt = Callable[[Context], Client]

ClientAppCallable = Callable[[Message, Context], Message]
Mod = Callable[[Message, Context, ClientAppCallable], Message]
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
return {"result": result}


def get_dummy_client(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
def get_dummy_client(context: Context) -> Client: # pylint: disable=unused-argument
"""Return a DummyClient converted to Client type."""
return DummyClient().to_client()

Expand Down
16 changes: 9 additions & 7 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ def __init__(
super().__init__(cid=str(node_id))
self.node_id = node_id
self.partition_id = partition_id
print(node_id, partition_id)
jafermarq marked this conversation as resolved.
Show resolved Hide resolved

def _load_app() -> ClientApp:
return ClientApp(client_fn=client_fn)

self.app_fn = _load_app
self.actor_pool = actor_pool
self.proxy_state = NodeState(
node_id=node_id, node_config={}, partition_id=self.partition_id
node_id=node_id, node_config={}, partition_id=partition_id
)

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
Expand All @@ -70,18 +71,19 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
# Register state
self.proxy_state.register_context(run_id=run_id)

# Retrieve state
state = self.proxy_state.retrieve_context(run_id=run_id)
# Retrieve context
context = self.proxy_state.retrieve_context(run_id=run_id)
partition_id = str(context.partition_id)
jafermarq marked this conversation as resolved.
Show resolved Hide resolved

try:
self.actor_pool.submit_client_job(
lambda a, a_fn, mssg, partition_id, state: a.run.remote(
a_fn, mssg, partition_id, state
lambda a, a_fn, mssg, partition_id, context: a.run.remote(
a_fn, mssg, partition_id, context
),
(self.app_fn, message, str(self.partition_id), state),
(self.app_fn, message, partition_id, context),
)
out_mssg, updated_context = self.actor_pool.get_client_result(
str(self.partition_id), timeout
partition_id, timeout
)

# Update state
Expand Down
32 changes: 17 additions & 15 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from math import pi
from random import shuffle
from typing import Dict, List, Optional, Tuple, Type
from typing import Dict, List, Tuple, Type

import ray

Expand All @@ -39,7 +39,10 @@
recordset_to_getpropertiesres,
)
from flwr.common.recordset_compat_test import _get_valid_getpropertiesins
from flwr.simulation.app import _create_node_id_to_partition_mapping
from flwr.simulation.app import (
NodeToPartitionMapping,
_create_node_id_to_partition_mapping,
)
from flwr.simulation.ray_transport.ray_actor import (
ClientAppActor,
VirtualClientEngineActor,
Expand All @@ -65,16 +68,16 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
return {"result": result}


def get_dummy_client(
node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
) -> Client:
def get_dummy_client(context: Context) -> Client:
"""Return a DummyClient converted to Client type."""
return DummyClient(node_id).to_client()
return DummyClient(context.node_id).to_client()


def prep(
actor_type: Type[VirtualClientEngineActor] = ClientAppActor,
) -> Tuple[List[RayActorClientProxy], VirtualClientEngineActorPool]: # pragma: no cover
) -> Tuple[
List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping
]: # pragma: no cover
"""Prepare ClientProxies and pool for tests."""
client_resources = {"num_cpus": 1, "num_gpus": 0.0}

Expand All @@ -101,15 +104,15 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:
for node_id, partition_id in mapping.items()
]

return proxies, pool
return proxies, pool, mapping


def test_cid_consistency_one_at_a_time() -> None:
"""Test that ClientProxies get the result of client job they submit.

Submit one job and waits for completion. Then submits the next and so on
"""
proxies, _ = prep()
proxies, _, _ = prep()

getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)
Expand Down Expand Up @@ -139,7 +142,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
All jobs are submitted at the same time. Then fetched one at a time. This also tests
NodeState (at each Proxy) and RunState basic functionality.
"""
proxies, _ = prep()
proxies, _, _ = prep()
run_id = 0

getproperties_ins = _get_valid_getpropertiesins()
Expand Down Expand Up @@ -186,9 +189,8 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:

def test_cid_consistency_without_proxies() -> None:
"""Test cid consistency of jobs submitted/retrieved to/from pool w/o ClientProxy."""
proxies, pool = prep()
num_clients = len(proxies)
node_ids = list(range(num_clients))
_, pool, mapping = prep()
node_ids = list(mapping.keys())

getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)
Expand Down Expand Up @@ -219,11 +221,11 @@ def _load_app() -> ClientApp:
message,
str(node_id),
Context(
node_id=0,
node_id=node_id,
node_config={},
state=RecordSet(),
run_config={},
partition_id=node_id,
partition_id=mapping[node_id],
),
),
)
Expand Down