Skip to content

Commit

Permalink
refactor(framework) Remove partition_id from Context (#3792)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored and chongshenng committed Jul 16, 2024
1 parent 889eadf commit 76244be
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 43 deletions.
2 changes: 0 additions & 2 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ def _on_backoff(retry_state: RetryState) -> None:
node_state = NodeState(
node_id=-1,
node_config={},
partition_id=None,
)
else:
# Call create_node fn to register node
Expand All @@ -360,7 +359,6 @@ def _on_backoff(retry_state: RetryState) -> None:
node_state = NodeState(
node_id=node_id,
node_config=node_config,
partition_id=None,
)

app_state_tracker.register_signal_handler()
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(
self, node_id: int, node_config: Dict[str, str], partition_id: Optional[int]
self,
node_id: int,
node_config: Dict[str, str],
) -> None:
self.node_id = node_id
self.node_config = node_config
self.run_infos: Dict[int, RunInfo] = {}
self._partition_id = partition_id

def register_context(
self,
Expand All @@ -59,7 +60,6 @@ def register_context(
node_config=self.node_config,
state=RecordSet(),
run_config=initial_run_config.copy(),
partition_id=self._partition_id,
),
)

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}

# NodeState
node_state = NodeState(node_id=0, node_config={}, partition_id=None)
node_state = NodeState(node_id=0, node_config={})

for task in tasks:
run_id = task.run_id
Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
FAB_CONFIG_FILE = "pyproject.toml"
FLWR_HOME = "FLWR_HOME"

# Constant keys for Node config entries used in Simulation
PARTITION_ID_KEY = "partition-id"
NUM_PARTITIONS_KEY = "num-partitions"

GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
Expand Down
9 changes: 1 addition & 8 deletions src/py/flwr/common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict

from .record import RecordSet

Expand All @@ -43,28 +43,21 @@ class Context:
A config (key/value mapping) held by the entity in a given run and that will
stay local. It can be used at any point during the lifecycle of this entity
(e.g. across multiple rounds)
partition_id : Optional[int] (default: None)
An index that specifies the data partition that the ClientApp using this Context
object should make use of. Setting this attribute is better suited for
simulation or prototyping setups.
"""

node_id: int
node_config: Dict[str, str]
state: RecordSet
run_config: Dict[str, str]
partition_id: Optional[int]

def __init__( # pylint: disable=too-many-arguments
self,
node_id: int,
node_config: Dict[str, str],
state: RecordSet,
run_config: Dict[str, str],
partition_id: Optional[int] = None,
) -> None:
self.node_id = node_id
self.node_config = node_config
self.state = state
self.run_config = run_config
self.partition_id = partition_id
3 changes: 2 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ray

from flwr.client.client_app import ClientApp
from flwr.common.constant import PARTITION_ID_KEY
from flwr.common.context import Context
from flwr.common.logger import log
from flwr.common.message import Message
Expand Down Expand Up @@ -168,7 +169,7 @@ def process_message(
Return output message and updated context.
"""
partition_id = context.partition_id
partition_id = context.node_config[PARTITION_ID_KEY]

try:
# Submit a task to the pool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from flwr.client import Client, NumPyClient
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.node_state import NodeState
from flwr.common import (
DEFAULT_TTL,
Config,
Expand All @@ -32,9 +33,9 @@
Message,
MessageTypeLegacy,
Metadata,
RecordSet,
Scalar,
)
from flwr.common.constant import PARTITION_ID_KEY
from flwr.common.object_ref import load_app
from flwr.common.recordset_compat import getpropertiesins_to_recordset
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
Expand Down Expand Up @@ -101,12 +102,13 @@ def _create_message_and_context() -> Tuple[Message, Context, float]:

# Construct a Message
mult_factor = 2024
run_id = 0
getproperties_ins = GetPropertiesIns(config={"factor": mult_factor})
recordset = getpropertiesins_to_recordset(getproperties_ins)
message = Message(
content=recordset,
metadata=Metadata(
run_id=0,
run_id=run_id,
message_id="",
group_id="",
src_node_id=0,
Expand All @@ -117,8 +119,10 @@ def _create_message_and_context() -> Tuple[Message, Context, float]:
),
)

# Construct empty Context
context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})
# Construct NodeState and retrieve context
node_state = NodeState(node_id=run_id, node_config={PARTITION_ID_KEY: str(0)})
node_state.register_context(run_id=run_id)
context = node_state.retrieve_context(run_id=run_id)

# Expected output
expected_output = pi * mult_factor
Expand Down
17 changes: 13 additions & 4 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@

from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
from flwr.client.node_state import NodeState
from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
from flwr.common.constant import (
NUM_PARTITIONS_KEY,
PARTITION_ID_KEY,
PING_MAX_INTERVAL,
ErrorCode,
)
from flwr.common.logger import log
from flwr.common.message import Error
from flwr.common.object_ref import load_app
Expand Down Expand Up @@ -73,7 +78,7 @@ def worker(
task_ins: TaskIns = taskins_queue.get(timeout=1.0)
node_id = task_ins.task.consumer.node_id

# Register and retrieve runstate
# Register and retrieve context
node_states[node_id].register_context(run_id=task_ins.run_id)
context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)

Expand Down Expand Up @@ -283,11 +288,15 @@ def start_vce(

# Construct mapping of NodeStates
node_states: Dict[int, NodeState] = {}
# Number of unique partitions
num_partitions = len(set(nodes_mapping.values()))
for node_id, partition_id in nodes_mapping.items():
node_states[node_id] = NodeState(
node_id=node_id,
node_config={"partition-id": str(partition_id)},
partition_id=None,
node_config={
PARTITION_ID_KEY: str(partition_id),
NUM_PARTITIONS_KEY: str(num_partitions),
},
)

# Load backend config
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/simulation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def update_resources(f_stop: threading.Event) -> None:
client_fn=client_fn,
node_id=node_id,
partition_id=partition_id,
num_partitions=num_clients,
actor_pool=pool,
)
initialized_server.client_manager().register(client=client_proxy)
Expand Down
18 changes: 13 additions & 5 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from flwr.client.client_app import ClientApp
from flwr.client.node_state import NodeState
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
from flwr.common.constant import MessageType, MessageTypeLegacy
from flwr.common.constant import (
NUM_PARTITIONS_KEY,
PARTITION_ID_KEY,
MessageType,
MessageTypeLegacy,
)
from flwr.common.logger import log
from flwr.common.recordset_compat import (
evaluateins_to_recordset,
Expand All @@ -43,11 +48,12 @@
class RayActorClientProxy(ClientProxy):
"""Flower client proxy which delegates work using Ray."""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
client_fn: ClientFnExt,
node_id: int,
partition_id: int,
num_partitions: int,
actor_pool: VirtualClientEngineActorPool,
):
super().__init__(cid=str(node_id))
Expand All @@ -61,8 +67,10 @@ def _load_app() -> ClientApp:
self.actor_pool = actor_pool
self.proxy_state = NodeState(
node_id=node_id,
node_config={"partition-id": str(partition_id)},
partition_id=None,
node_config={
PARTITION_ID_KEY: str(partition_id),
NUM_PARTITIONS_KEY: str(num_partitions),
},
)

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
Expand All @@ -74,7 +82,7 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:

# Retrieve context
context = self.proxy_state.retrieve_context(run_id=run_id)
partition_id_str = context.node_config["partition-id"]
partition_id_str = context.node_config[PARTITION_ID_KEY]

try:
self.actor_pool.submit_client_job(
Expand Down
38 changes: 23 additions & 15 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from flwr.client import Client, NumPyClient
from flwr.client.client_app import ClientApp
from flwr.client.node_state import NodeState
from flwr.common import (
DEFAULT_TTL,
Config,
Expand All @@ -31,9 +32,9 @@
Message,
MessageTypeLegacy,
Metadata,
RecordSet,
Scalar,
)
from flwr.common.constant import NUM_PARTITIONS_KEY, PARTITION_ID_KEY
from flwr.common.recordset_compat import (
getpropertiesins_to_recordset,
recordset_to_getpropertiesres,
Expand Down Expand Up @@ -99,6 +100,7 @@ def create_actor_fn() -> Type[VirtualClientEngineActor]:
client_fn=get_dummy_client,
node_id=node_id,
partition_id=partition_id,
num_partitions=num_proxies,
actor_pool=pool,
)
for node_id, partition_id in mapping.items()
Expand Down Expand Up @@ -192,6 +194,17 @@ def test_cid_consistency_without_proxies() -> None:
_, pool, mapping = prep()
node_ids = list(mapping.keys())

# register node states
node_states: Dict[int, NodeState] = {}
for node_id, partition_id in mapping.items():
node_states[node_id] = NodeState(
node_id=node_id,
node_config={
PARTITION_ID_KEY: str(partition_id),
NUM_PARTITIONS_KEY: str(len(node_ids)),
},
)

getproperties_ins = _get_valid_getpropertiesins()
recordset = getpropertiesins_to_recordset(getproperties_ins)

Expand All @@ -200,11 +213,12 @@ def _load_app() -> ClientApp:

# submit all jobs (collect later)
shuffle(node_ids)
run_id = 0
for node_id in node_ids:
message = Message(
content=recordset,
metadata=Metadata(
run_id=0,
run_id=run_id,
message_id="",
group_id=str(0),
src_node_id=0,
Expand All @@ -214,26 +228,20 @@ def _load_app() -> ClientApp:
message_type=MessageTypeLegacy.GET_PROPERTIES,
),
)
# register and retrieve context
node_states[node_id].register_context(run_id=run_id)
context = node_states[node_id].retrieve_context(run_id=run_id)
partition_id_str = context.node_config[PARTITION_ID_KEY]
pool.submit_client_job(
lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state),
(
_load_app,
message,
str(node_id),
Context(
node_id=node_id,
node_config={},
state=RecordSet(),
run_config={},
partition_id=mapping[node_id],
),
),
(_load_app, message, partition_id_str, context),
)

# fetch results one at a time
shuffle(node_ids)
for node_id in node_ids:
message_out, _ = pool.get_client_result(str(node_id), timeout=None)
partition_id_str = str(mapping[node_id])
message_out, _ = pool.get_client_result(partition_id_str, timeout=None)
res = recordset_to_getpropertiesres(message_out.content)
assert node_id * pi == res.properties["result"]

Expand Down

0 comments on commit 76244be

Please sign in to comment.