From dd1ef0d25486cefa4cf834648093692fcb27cca9 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 21 Sep 2023 19:59:41 +0100 Subject: [PATCH] update driver.py --- examples/secaggplus-mt/driver.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index c168edf070a..4e0a53ed1c9 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -23,7 +23,7 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: task_pb2.TaskIns( task_id="", # Do not set, will be created and set by the DriverAPI group_id="", - workload_id="", + workload_id=workload_id, task=merge( task, task_pb2.Task( @@ -84,8 +84,14 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK driver.connect() +create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( + req=driver_pb2.CreateWorkloadRequest() +) # -------------------------------------------------------------------------- Driver SDK +workload_id = create_workload_res.workload_id +print(f"Created workload id {workload_id}") + history = History() for server_round in range(num_rounds): print(f"Commencing server round {server_round + 1}") @@ -113,7 +119,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # loop and wait until enough client nodes are available. while True: # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest() + get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) # ---------------------------------------------------------------------- Driver SDK get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( @@ -121,7 +127,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: ) # ---------------------------------------------------------------------- Driver SDK - all_node_ids: List[int] = get_nodes_res.node_ids + all_node_ids: List[int] = [node.node_id for node in get_nodes_res.nodes] if len(all_node_ids) >= num_client_nodes_per_round: # Sample client nodes