Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple client instantiation and task execution for simulation #2331

Merged
merged 15 commits into from
Sep 12, 2023
2 changes: 2 additions & 0 deletions doc/source/ref-changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

- **General updates to baselines** ([#2301](https:/adap/flower/pull/2301).[#2305](https:/adap/flower/pull/2305), [#2307](https:/adap/flower/pull/2307), [#2327](https:/adap/flower/pull/2327))

- **General updates to the simulation engine** ([#2331](https:/adap/flower/pull/2331))

- **General improvements** ([#2309](https:/adap/flower/pull/2309), [#2310](https:/adap/flower/pull/2310), [2313](https:/adap/flower/pull/2313), [#2316](https:/adap/flower/pull/2316), [2317](https:/adap/flower/pull/2317),[#2349](https:/adap/flower/pull/2349))

Flower received many improvements under the hood, too many to list here.
Expand Down
26 changes: 19 additions & 7 deletions src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@
from ray.util.actor_pool import ActorPool

from flwr import common
from flwr.client import Client, ClientFn, to_client
from flwr.common.logger import log

# All possible returns by a client
ClientRes = Union[
common.GetPropertiesRes, common.GetParametersRes, common.FitRes, common.EvaluateRes
]
# A function to be executed by a client to obtain some results
ClientJobFn = Callable[[Client], ClientRes]


class ClientException(Exception):
Expand All @@ -51,13 +54,22 @@ def terminate(self) -> None:
log(WARNING, "Manually terminating %s}", self.__class__.__name__)
ray.actor.exit_actor()

def run(self, job_fn: Callable[[], ClientRes], cid: str) -> Tuple[str, ClientRes]:
def run(
self,
client_fn: ClientFn,
job_fn: ClientJobFn,
cid: str,
) -> Tuple[str, ClientRes]:
"""Run a client workload."""
# Execute tasks and return result
# return also cid which is needed to ensure results
# from the pool are correctly assigned to each ClientProxy
try:
job_results = job_fn()
# Instantiate client
client_like = client_fn(cid)
client = to_client(client_like=client_like)
# Run client job
job_results = job_fn(client)
except Exception as ex:
client_trace = traceback.format_exc()
message = (
Expand Down Expand Up @@ -219,16 +231,16 @@ def add_actors_to_pool(self, num_actors: int) -> None:
self._idle_actors.extend(new_actors)
self.num_actors += num_actors

def submit(self, fn: Any, value: Tuple[Callable[[], ClientRes], str]) -> None:
def submit(self, fn: Any, value: Tuple[ClientFn, ClientJobFn, str]) -> None:
"""Take idle actor and assign it a client workload.

Submit a job to an actor by first removing it from the list of idle actors, then
check if this actor was flagged to be removed from the pool
"""
job_fn, cid = value
client_fn, job_fn, cid = value
actor = self._idle_actors.pop()
if self._check_and_remove_actor_from_pool(actor):
future = fn(actor, job_fn, cid)
future = fn(actor, client_fn, job_fn, cid)
future_key = tuple(future) if isinstance(future, List) else future
self._future_to_actor[future_key] = (self._next_task_index, actor, cid)
self._next_task_index += 1
Expand All @@ -237,10 +249,10 @@ def submit(self, fn: Any, value: Tuple[Callable[[], ClientRes], str]) -> None:
self._cid_to_future[cid]["future"] = future_key

def submit_client_job(
self, actor_fn: Any, job: Tuple[Callable[[], ClientRes], str]
self, actor_fn: Any, job: Tuple[ClientFn, ClientJobFn, str]
) -> None:
"""Submit a job while tracking client ids."""
_, cid = job
_, _, cid = job

# We need to put this behind a lock since .submit() involves
# removing and adding elements from a dictionary. Which creates
Expand Down
22 changes: 9 additions & 13 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import traceback
from logging import ERROR
from typing import Callable, Dict, Optional, cast
from typing import Dict, Optional, cast

import ray

Expand All @@ -32,6 +32,7 @@
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.simulation.ray_transport.ray_actor import (
ClientJobFn,
ClientRes,
VirtualClientEngineActorPool,
)
Expand Down Expand Up @@ -128,12 +129,11 @@ def __init__(
self.client_fn = client_fn
self.actor_pool = actor_pool

def _submit_job(
self, job_fn: Callable[[], ClientRes], timeout: Optional[float]
) -> ClientRes:
def _submit_job(self, job_fn: ClientJobFn, timeout: Optional[float]) -> ClientRes:
try:
self.actor_pool.submit_client_job(
lambda a, v, cid: a.run.remote(v, cid), (job_fn, self.cid)
lambda a, c_fn, j_fn, cid: a.run.remote(c_fn, j_fn, cid),
(self.client_fn, job_fn, self.cid),
)
res = self.actor_pool.get_client_result(self.cid, timeout)

Expand All @@ -153,8 +153,7 @@ def get_properties(
) -> common.GetPropertiesRes:
"""Return client's properties."""

def get_properties() -> common.GetPropertiesRes:
client: Client = _create_client(self.client_fn, self.cid)
def get_properties(client: Client) -> common.GetPropertiesRes:
return maybe_call_get_properties(
client=client,
get_properties_ins=ins,
Expand All @@ -172,8 +171,7 @@ def get_parameters(
) -> common.GetParametersRes:
"""Return the current local model parameters."""

def get_parameters() -> common.GetParametersRes:
client: Client = _create_client(self.client_fn, self.cid)
def get_parameters(client: Client) -> common.GetParametersRes:
return maybe_call_get_parameters(
client=client,
get_parameters_ins=ins,
Expand All @@ -189,8 +187,7 @@ def get_parameters() -> common.GetParametersRes:
def fit(self, ins: common.FitIns, timeout: Optional[float]) -> common.FitRes:
"""Train model parameters on the locally held dataset."""

def fit() -> common.FitRes:
client: Client = _create_client(self.client_fn, self.cid)
def fit(client: Client) -> common.FitRes:
return maybe_call_fit(
client=client,
fit_ins=ins,
Expand All @@ -208,8 +205,7 @@ def evaluate(
) -> common.EvaluateRes:
"""Evaluate model parameters on the locally held dataset."""

def evaluate() -> common.EvaluateRes:
client: Client = _create_client(self.client_fn, self.cid)
def evaluate(client: Client) -> common.EvaluateRes:
return maybe_call_evaluate(
client=client,
evaluate_ins=ins,
Expand Down
34 changes: 24 additions & 10 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

from math import pi
from random import shuffle
from typing import Callable, List, Tuple, Type, cast
from typing import List, Tuple, Type, cast

import ray

from flwr.client import NumPyClient
from flwr.client import Client, NumPyClient
from flwr.common import Code, GetPropertiesRes, Status
from flwr.simulation.ray_transport.ray_actor import (
ClientJobFn,
ClientRes,
DefaultActor,
VirtualClientEngineActor,
Expand All @@ -32,11 +33,23 @@
from flwr.simulation.ray_transport.ray_client_proxy import RayActorClientProxy


class DummyClient(NumPyClient):
"""A dummy NumPyClient for tests."""

def __init__(self, cid: str) -> None:
self.cid = int(cid)


def get_dummy_client(cid: str) -> DummyClient:
"""Return a DummyClient."""
return DummyClient(cid)


# A dummy workload
def job_fn(cid: str) -> Callable[[], ClientRes]: # pragma: no cover
def job_fn(cid: str) -> ClientJobFn: # pragma: no cover
"""Construct a simple job with cid dependency."""

def cid_times_pi() -> ClientRes:
def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument
result = int(cid) * pi

# now let's convert it to a GetPropertiesRes response
Expand All @@ -63,14 +76,11 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:
client_resources=client_resources,
)

def dummy_client(cid: str) -> NumPyClient: # pylint: disable=unused-argument
return NumPyClient()

# Create 373 client proxies
num_proxies = 373 # a prime number
proxies = [
RayActorClientProxy(
client_fn=dummy_client,
client_fn=get_dummy_client,
cid=str(cid),
actor_pool=pool,
)
Expand Down Expand Up @@ -110,7 +120,8 @@ def test_cid_consistency_all_submit_first() -> None:
for prox in proxies:
job = job_fn(prox.cid)
prox.actor_pool.submit_client_job(
lambda a, v, cid: a.run.remote(v, cid), (job, prox.cid)
lambda a, c_fn, j_fn, cid: a.run.remote(c_fn, j_fn, cid),
(prox.client_fn, job, prox.cid),
)

# fetch results one at a time
Expand All @@ -133,7 +144,10 @@ def test_cid_consistency_without_proxies() -> None:
shuffle(cids)
for cid in cids:
job = job_fn(cid)
pool.submit_client_job(lambda a, v, cid_: a.run.remote(v, cid_), (job, cid))
pool.submit_client_job(
lambda a, c_fn, j_fn, cid_: a.run.remote(c_fn, j_fn, cid_),
(get_dummy_client, job, cid),
)

# fetch results one at a time
shuffle(cids)
Expand Down
Loading