diff --git a/examples/app-secure-aggregation/README.md b/examples/app-secure-aggregation/README.md
deleted file mode 100644
index 8e483fb2f6b..00000000000
--- a/examples/app-secure-aggregation/README.md
+++ /dev/null
@@ -1,99 +0,0 @@
----
-tags: [basic, vision, fds]
-dataset: []
-framework: [numpy]
----
-
-# Secure aggregation with Flower (the SecAgg+ protocol) 🧪
-
-> 🧪 = This example covers experimental features that might change in future versions of Flower
-> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.
-
-The following steps describe how to use Secure Aggregation in flower, with `ClientApp` using `secaggplus_mod` and `ServerApp` using `SecAggPlusWorkflow`.
-
-## Preconditions
-
-Let's assume the following project structure:
-
-```bash
-$ tree .
-.
-├── client.py            # Client application using `secaggplus_mod`
-├── server.py            # Server application using `SecAggPlusWorkflow`
-├── workflow_with_log.py # Augmented `SecAggPlusWorkflow`
-├── run.sh               # Quick start script
-├── pyproject.toml       # Project dependencies (poetry)
-└── requirements.txt     # Project dependencies (pip)
-```
-
-## Installing dependencies
-
-Project dependencies (such as and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
-
-### Poetry
-
-```shell
-poetry install
-poetry shell
-```
-
-Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
-
-```shell
-poetry run python3 -c "import flwr"
-```
-
-### pip
-
-Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.
-
-```shell
-pip install -r requirements.txt
-```
-
-If you don't see any errors you're good to go!
-
-## Run the example with one command (recommended)
-
-```bash
-./run.sh
-```
-
-## Run the example with the simulation engine
-
-```bash
-flower-simulation --server-app server:app --client-app client:app --num-supernodes 5
-```
-
-## Alternatively, run the example (in 7 terminal windows)
-
-Start the Flower Superlink in one terminal window:
-
-```bash
-flower-superlink --insecure
-```
-
-Start 5 Flower `ClientApp` in 5 separate terminal windows:
-
-```bash
-flower-client-app client:app --insecure
-```
-
-Start the Flower `ServerApp`:
-
-```bash
-flower-server-app server:app --insecure --verbose
-```
-
-## Amend the example for practical usage
-
-For real-world applications, modify the `workflow` in `server.py` as follows:
-
-```python
-workflow = fl.server.workflow.DefaultWorkflow(
-    fit_workflow=SecAggPlusWorkflow(
-        num_shares=,
-        reconstruction_threshold=,
-    )
-)
-```
diff --git a/examples/app-secure-aggregation/client.py b/examples/app-secure-aggregation/client.py
deleted file mode 100644
index b2fd02ec00d..00000000000
--- a/examples/app-secure-aggregation/client.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import time
-
-from flwr.client import ClientApp, NumPyClient
-from flwr.client.mod import secaggplus_mod
-import numpy as np
-
-
-# Define FlowerClient and client_fn
-class FlowerClient(NumPyClient):
-    def fit(self, parameters, config):
-        # Instead of training and returning model parameters,
-        # the client directly returns [1.0, 1.0, 1.0] for demonstration purposes.
-        ret_vec = [np.ones(3)]
-        # Force a significant delay for testing purposes
-        if "drop" in config and config["drop"]:
-            print(f"Client dropped for testing purposes.")
-            time.sleep(8)
-        else:
-            print(f"Client uploading {ret_vec[0]}...")
-        return ret_vec, 1, {}
-
-
-def client_fn(cid: str):
-    """Create and return an instance of Flower `Client`."""
-    return FlowerClient().to_client()
-
-
-# Flower ClientApp
-app = ClientApp(
-    client_fn=client_fn,
-    mods=[
-        secaggplus_mod,
-    ],
-)
diff --git a/examples/app-secure-aggregation/pyproject.toml b/examples/app-secure-aggregation/pyproject.toml
deleted file mode 100644
index fb1f636d8c3..00000000000
--- a/examples/app-secure-aggregation/pyproject.toml
+++ /dev/null
@@ -1,14 +0,0 @@
-[build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.core.masonry.api"
-
-[tool.poetry]
-name = "app-secure-aggregation"
-version = "0.1.0"
-description = "Flower Secure Aggregation example."
-authors = ["The Flower Authors "]
-
-[tool.poetry.dependencies]
-python = "^3.8"
-# Mandatory dependencies
-flwr = { version = "^1.8.0", extras = ["simulation"] }
diff --git a/examples/app-secure-aggregation/requirements.txt b/examples/app-secure-aggregation/requirements.txt
deleted file mode 100644
index 2d8be098f26..00000000000
--- a/examples/app-secure-aggregation/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-flwr[simulation]>=1.8.0
diff --git a/examples/app-secure-aggregation/run.sh b/examples/app-secure-aggregation/run.sh
deleted file mode 100755
index fa8dc47f26e..00000000000
--- a/examples/app-secure-aggregation/run.sh
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/bin/bash
-# Kill any currently running client.py processes
-pkill -f 'flower-client-app'
-
-# Kill any currently running flower-superlink processes
-pkill -f 'flower-superlink'
-
-# Start the flower server
-echo "Starting flower server in background..."
-flower-superlink --insecure > /dev/null 2>&1 &
-sleep 2
-
-# Number of client processes to start
-N=5 # Replace with your desired value
-
-echo "Starting $N ClientApps in background..."
-
-# Start N client processes
-for i in $(seq 1 $N)
-do
-    flower-client-app --insecure client:app > /dev/null 2>&1 &
-    sleep 0.1
-done
-
-echo "Starting ServerApp..."
-flower-server-app --insecure server:app --verbose
-
-echo "Clearing background processes..."
-
-# Kill any currently running client.py processes
-pkill -f 'flower-client-app'
-
-# Kill any currently running flower-superlink processes
-pkill -f 'flower-superlink'
diff --git a/examples/app-secure-aggregation/server.py b/examples/app-secure-aggregation/server.py
deleted file mode 100644
index ebd70045fdc..00000000000
--- a/examples/app-secure-aggregation/server.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from flwr.common import Context
-from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
-from flwr.server.strategy import FedAvg
-from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow
-
-from workflow_with_log import SecAggPlusWorkflowWithLogs
-
-
-# Define strategy
-strategy = FedAvg(
-    fraction_fit=1.0,  # Select all available clients
-    fraction_evaluate=0.0,  # Disable evaluation
-    min_available_clients=5,
-)
-
-
-# Flower ServerApp
-app = ServerApp()
-
-
-@app.main()
-def main(driver: Driver, context: Context) -> None:
-    # Construct the LegacyContext
-    context = LegacyContext(
-        context=context,
-        config=ServerConfig(num_rounds=3),
-        strategy=strategy,
-    )
-
-    # Create the workflow
-    workflow = DefaultWorkflow(
-        fit_workflow=SecAggPlusWorkflowWithLogs(
-            num_shares=3,
-            reconstruction_threshold=2,
-            timeout=5,
-        )
-        # # For real-world applications, use the following code instead
-        # fit_workflow=SecAggPlusWorkflow(
-        #     num_shares=,
-        #     reconstruction_threshold=,
-        # )
-    )
-
-    # Execute
-    workflow(driver, context)
diff --git a/examples/flower-secure-aggregation/README.md b/examples/flower-secure-aggregation/README.md
new file mode 100644
index 00000000000..9e92aed01d9
--- /dev/null
+++ b/examples/flower-secure-aggregation/README.md
@@ -0,0 +1,72 @@
+---
+tags: [advanced, secure_aggregation, privacy]
+dataset: [CIFAR-10]
+framework: [torch, torchvision]
+---
+
+# Secure aggregation with Flower (the SecAgg+ protocol)
+
+The following steps describe how to use Flower's built-in Secure Aggregation components. This example demonstrates how to apply `SecAgg+` to the same federated learning workload as in the [quickstart-pytorch](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) example. The `ServerApp` uses the [`SecAggPlusWorkflow`](https://flower.ai/docs/framework/ref-api/flwr.server.workflow.SecAggPlusWorkflow.html#secaggplusworkflow), while the `ClientApp` uses the [`secaggplus_mod`](https://flower.ai/docs/framework/ref-api/flwr.client.mod.secaggplus_mod.html#flwr.client.mod.secaggplus_mod). To illustrate the individual steps of `SecAgg+`, this example uses `SecAggPlusWorkflowWithLogs`, a subclass of `SecAggPlusWorkflow` that logs what happens at each stage. It is enabled by default, but you can disable it (see later in this readme).
+
+## Set up the project
+
+### Clone the project
+
+Start by cloning the example project:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git _tmp \
+        && mv _tmp/examples/flower-secure-aggregation . \
+        && rm -rf _tmp && cd flower-secure-aggregation
+```
+
+This will create a new directory called `flower-secure-aggregation` containing the
+following files:
+
+```shell
+flower-secure-aggregation
+|
+├── secaggexample
+|   ├── __init__.py
+|   ├── client_app.py        # Defines your ClientApp
+|   ├── server_app.py        # Defines your ServerApp
+|   ├── task.py              # Defines your model, training and data loading
+|   └── workflow_with_log.py # Defines a workflow used when `is-demo=true`
+├── pyproject.toml           # Project metadata like dependencies and configs
+└── README.md
+```
+
+### Install dependencies and project
+
+Install the dependencies defined in `pyproject.toml` as well as the `secaggexample` package.
+
+```bash
+pip install -e .
+```
+
+## Run the project
+
+You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode, as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.
+
+### Run with the Simulation Engine
+
+```bash
+flwr run .
+```
+
+You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:
+
+```bash
+flwr run . --run-config num-server-rounds=5,learning-rate=0.25
+```
+
+To adapt the example for practical usage, set `is-demo=false` as shown below. You might want to adjust the `num-shares` and `reconstruction-threshold` settings to suit your requirements; you can override those via `--run-config` as well.
+
+```bash
+flwr run . --run-config is-demo=false
+```
+
+### Run with the Deployment Engine
+
+> \[!NOTE\]
+> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker.
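+
+## How the SecAgg+ pieces fit together
+
+For orientation, here is a condensed view of how this example wires up the protocol: the `ClientApp` adds the `secaggplus_mod` mod so that model updates are masked before they leave the client, and the `ServerApp` wraps its fit round in a `SecAggPlusWorkflow`. The sketch below is a simplified excerpt of what `secaggexample/client_app.py` and `secaggexample/server_app.py` already do (reusing `client_fn` from this example); the parameter values shown are just the defaults from `pyproject.toml`, not required settings.
+
+```python
+from flwr.client import ClientApp
+from flwr.client.mod import secaggplus_mod
+from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow
+
+from secaggexample.client_app import client_fn
+
+# ClientApp side: the mod masks each model update before it is sent to the server
+app = ClientApp(client_fn=client_fn, mods=[secaggplus_mod])
+
+# ServerApp side: the workflow runs the SecAgg+ stages around the regular fit round
+workflow = DefaultWorkflow(
+    fit_workflow=SecAggPlusWorkflow(
+        num_shares=3,
+        reconstruction_threshold=2,
+    )
+)
+```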
diff --git a/examples/flower-secure-aggregation/pyproject.toml b/examples/flower-secure-aggregation/pyproject.toml
new file mode 100644
index 00000000000..d9be719653b
--- /dev/null
+++ b/examples/flower-secure-aggregation/pyproject.toml
@@ -0,0 +1,46 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "secaggexample"
+version = "1.0.0"
+description = "Secure Aggregation in Flower"
+license = "Apache-2.0"
+dependencies = [
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets[vision]>=0.3.0",
+    "torch==2.2.1",
+    "torchvision==0.17.1",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "flwrlabs"
+
+[tool.flwr.app.components]
+serverapp = "secaggexample.server_app:app"
+clientapp = "secaggexample.client_app:app"
+
+
+[tool.flwr.app.config]
+num-server-rounds = 3
+fraction-evaluate = 0.5
+local-epochs = 1
+learning-rate = 0.1
+batch-size = 32
+# Parameters for the SecAgg+ protocol
+num-shares = 3
+reconstruction-threshold = 2
+max-weight = 9000
+timeout = 15.0
+# Demo flag
+is-demo = true
+
+[tool.flwr.federations]
+default = "local-simulation"
+
+[tool.flwr.federations.local-simulation]
+options.num-supernodes = 5
diff --git a/examples/flower-secure-aggregation/secaggexample/__init__.py b/examples/flower-secure-aggregation/secaggexample/__init__.py
new file mode 100644
index 00000000000..366ceebfae8
--- /dev/null
+++ b/examples/flower-secure-aggregation/secaggexample/__init__.py
@@ -0,0 +1 @@
+"""secaggexample: A Flower with SecAgg+ app."""
diff --git a/examples/flower-secure-aggregation/secaggexample/client_app.py b/examples/flower-secure-aggregation/secaggexample/client_app.py
new file mode 100644
index 00000000000..7f4fd54b98b
--- /dev/null
+++ b/examples/flower-secure-aggregation/secaggexample/client_app.py
@@ -0,0 +1,91 @@
+"""secaggexample: A Flower with SecAgg+ app."""
+
+import time
+
+import torch
+from flwr.client import ClientApp, NumPyClient
+from flwr.client.mod import secaggplus_mod
+from flwr.common import Context
+
+from secaggexample.task import Net, get_weights, load_data, set_weights, test, train
+
+
+# Define Flower Client
+class FlowerClient(NumPyClient):
+    def __init__(
+        self, trainloader, valloader, local_epochs, learning_rate, timeout, is_demo
+    ):
+        self.net = Net()
+        self.trainloader = trainloader
+        self.valloader = valloader
+        self.local_epochs = local_epochs
+        self.lr = learning_rate
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        # For demonstration purposes only
+        self.timeout = timeout
+        self.is_demo = is_demo
+
+    def fit(self, parameters, config):
+        """Train the model with data of this client."""
+        set_weights(self.net, parameters)
+        results = {}
+        if not self.is_demo:
+            results = train(
+                self.net,
+                self.trainloader,
+                self.valloader,
+                self.local_epochs,
+                self.lr,
+                self.device,
+            )
+        ret_vec = get_weights(self.net)
+
+        # Force a significant delay for testing purposes
+        if self.is_demo:
+            if config.get("drop", False):
+                print("Client dropped for testing purposes.")
+                time.sleep(self.timeout)
+            else:
+                print(f"Client uploading parameters: {ret_vec[0].flatten()[:3]}...")
+        return ret_vec, len(self.trainloader.dataset), results
+
+    def evaluate(self, parameters, config):
+        """Evaluate the model on the data this client has."""
+        set_weights(self.net, parameters)
+        loss, accuracy = 0.0, 0.0
+        if not self.is_demo:
+            loss, accuracy = test(self.net, self.valloader, self.device)
+        return loss, len(self.valloader.dataset), {"accuracy": accuracy}
+
+
+def client_fn(context: Context):
+    """Construct a Client that will be run in a ClientApp."""
+
+    # Read the node_config to fetch data partition associated with this node
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+
+    # Read run_config to fetch hyperparameters relevant to this run
+    batch_size = context.run_config["batch-size"]
+    is_demo = context.run_config["is-demo"]
+    trainloader, valloader = load_data(
+        partition_id, num_partitions, batch_size, is_demo
+    )
+    local_epochs = context.run_config["local-epochs"]
+    lr = context.run_config["learning-rate"]
+    # For demonstration purposes only
+    timeout = context.run_config["timeout"]
+
+    # Return Client instance
+    return FlowerClient(
+        trainloader, valloader, local_epochs, lr, timeout, is_demo
+    ).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn=client_fn,
+    mods=[
+        secaggplus_mod,
+    ],
+)
diff --git a/examples/flower-secure-aggregation/secaggexample/server_app.py b/examples/flower-secure-aggregation/secaggexample/server_app.py
new file mode 100644
index 00000000000..a332ffb9eca
--- /dev/null
+++ b/examples/flower-secure-aggregation/secaggexample/server_app.py
@@ -0,0 +1,81 @@
+"""secaggexample: A Flower with SecAgg+ app."""
+
+from logging import DEBUG
+from typing import List, Tuple
+
+from secaggexample.task import get_weights, make_net
+from secaggexample.workflow_with_log import SecAggPlusWorkflowWithLogs
+
+from flwr.common import Context, Metrics, ndarrays_to_parameters
+from flwr.common.logger import update_console_handler
+
+from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
+from flwr.server.strategy import FedAvg
+from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    # Multiply accuracy of each client by number of examples used
+    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {"accuracy": sum(accuracies) / sum(examples)}
+
+
+# Flower ServerApp
+app = ServerApp()
+
+
+@app.main()
+def main(driver: Driver, context: Context) -> None:
+
+    is_demo = context.run_config["is-demo"]
+
+    # Get initial parameters
+    ndarrays = get_weights(make_net())
+    parameters = ndarrays_to_parameters(ndarrays)
+
+    # Define strategy
+    strategy = FedAvg(
+        # Select all available clients
+        fraction_fit=1.0,
+        # Disable evaluation in demo
+        fraction_evaluate=(0.0 if is_demo else context.run_config["fraction-evaluate"]),
+        min_available_clients=5,
+        evaluate_metrics_aggregation_fn=weighted_average,
+        initial_parameters=parameters,
+    )
+
+    # Construct the LegacyContext
+    num_rounds = context.run_config["num-server-rounds"]
+    context = LegacyContext(
+        context=context,
+        config=ServerConfig(num_rounds=num_rounds),
+        strategy=strategy,
+    )
+
+    # Create fit workflow
+    # For further information, please see:
+    # https://flower.ai/docs/framework/ref-api/flwr.server.workflow.SecAggPlusWorkflow.html
+    if is_demo:
+        update_console_handler(DEBUG, True, True)
+        fit_workflow = SecAggPlusWorkflowWithLogs(
+            num_shares=context.run_config["num-shares"],
+            reconstruction_threshold=context.run_config["reconstruction-threshold"],
+            max_weight=1,
+            timeout=context.run_config["timeout"],
+        )
+    else:
+        fit_workflow = SecAggPlusWorkflow(
+            num_shares=context.run_config["num-shares"],
+            reconstruction_threshold=context.run_config["reconstruction-threshold"],
+            max_weight=context.run_config["max-weight"],
+        )
+
+    # Create the workflow
+    workflow = DefaultWorkflow(fit_workflow=fit_workflow)
+
+    # Execute
+    workflow(driver, context)
diff --git a/examples/flower-secure-aggregation/secaggexample/task.py b/examples/flower-secure-aggregation/secaggexample/task.py
new file mode 100644
index 00000000000..e9cca8ef911
--- /dev/null
+++ b/examples/flower-secure-aggregation/secaggexample/task.py
@@ -0,0 +1,128 @@
+"""secaggexample: A Flower with SecAgg+ app."""
+
+import random
+from collections import OrderedDict
+from unittest.mock import Mock
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, Normalize, ToTensor
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def make_net(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    return Net()
+
+
+def get_weights(net):
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_weights(net, parameters):
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)
+
+
+fds = None  # Cache FederatedDataset
+
+
+def load_data(partition_id: int, num_partitions: int, batch_size: int, is_demo: bool):
+    """Load partitioned CIFAR-10 data."""
+    if is_demo:
+        trainloader, testloader = Mock(dataset=[0]), Mock(dataset=[0])
+        return trainloader, testloader
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="uoft-cs/cifar10",
+            partitioners={"train": partitioner},
+        )
+    partition = fds.load_partition(partition_id)
+    # Divide data on each node: 80% train, 20% test
+    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
+    pytorch_transforms = Compose(
+        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    def apply_transforms(batch):
+        """Apply transforms to the partition from FederatedDataset."""
+        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
+        return batch
+
+    partition_train_test = partition_train_test.with_transform(apply_transforms)
+    trainloader = DataLoader(
+        partition_train_test["train"], batch_size=batch_size, shuffle=True
+    )
+    testloader = DataLoader(partition_train_test["test"], batch_size=batch_size)
+    return trainloader, testloader
+
+
+def train(net, trainloader, valloader, epochs, learning_rate, device):
+    """Train the model on the training set."""
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
+    net.train()
+    for _ in range(epochs):
+        for batch in trainloader:
+            images = batch["img"]
+            labels = batch["label"]
+            optimizer.zero_grad()
+            criterion(net(images.to(device)), labels.to(device)).backward()
+            optimizer.step()
+
+    val_loss, val_acc = test(net, valloader, device)
+
+    results = {
+        "val_loss": val_loss,
+        "val_accuracy": val_acc,
+    }
+    return results
+
+
+def test(net, testloader, device):
+    """Validate the model on the test set."""
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for batch in testloader:
+            images = batch["img"].to(device)
+            labels = batch["label"].to(device)
+            outputs = net(images)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(testloader.dataset)
+    loss = loss / len(testloader)
+    return loss, accuracy
diff --git a/examples/app-secure-aggregation/workflow_with_log.py b/examples/flower-secure-aggregation/secaggexample/workflow_with_log.py
similarity index 74%
rename from examples/app-secure-aggregation/workflow_with_log.py
rename to examples/flower-secure-aggregation/secaggexample/workflow_with_log.py
index a03ff8c13b6..b2e457484de 100644
--- a/examples/app-secure-aggregation/workflow_with_log.py
+++ b/examples/flower-secure-aggregation/secaggexample/workflow_with_log.py
@@ -1,14 +1,18 @@
-from flwr.common import Context, log, parameters_to_ndarrays
+"""secaggexample: A Flower with SecAgg+ app."""
+
 from logging import INFO
+
+from secaggexample.task import get_weights, make_net
+
+import flwr.common.recordset_compat as compat
+from flwr.common import Context, log, parameters_to_ndarrays
+from flwr.common.secure_aggregation.quantization import quantize
 from flwr.server import Driver, LegacyContext
+from flwr.server.workflow.constant import MAIN_PARAMS_RECORD
 from flwr.server.workflow.secure_aggregation.secaggplus_workflow import (
     SecAggPlusWorkflow,
     WorkflowState,
 )
-import numpy as np
-from flwr.common.secure_aggregation.quantization import quantize
-from flwr.server.workflow.constant import MAIN_PARAMS_RECORD
-import flwr.common.recordset_compat as compat
 
 
 class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow):
@@ -21,8 +25,11 @@ class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow):
     node_ids = []
 
     def __call__(self, driver: Driver, context: Context) -> None:
+        first_3_params = get_weights(make_net())[0].flatten()[:3]
         _quantized = quantize(
-            [np.ones(3) for _ in range(5)], self.clipping_range, self.quantization_range
+            [first_3_params for _ in range(5)],
+            self.clipping_range,
+            self.quantization_range,
         )
         log(INFO, "")
         log(
@@ -31,24 +38,24 @@ def __call__(self, driver: Driver, context: Context) -> None:
         )
         log(
             INFO,
-            "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of",
+            "In the example, clients will skip model training and evaluation",
         )
-        log(INFO, "model updates for demonstration purposes.")
+        log(INFO, "for demonstration purposes.")
         log(
             INFO,
             "Client 0 is configured to drop out before uploading the masked vector.",
         )
         log(INFO, "After quantization, the raw vectors will look like:")
         for i in range(1, 5):
-            log(INFO, "\t%s from Client %s", _quantized[i], i)
+            log(INFO, "\t%s... from Client %s", _quantized[i], i)
from Client %s", _quantized[i], i) log( INFO, - "Numbers are rounded to integers stochastically during the quantization", + "Numbers are rounded to integers stochastically during the quantization, ", ) - log(INFO, ", and thus entries may not be identical.") + log(INFO, "and thus vectors may not be identical.") log( INFO, - "The above raw vectors are hidden from the driver through adding masks.", + "The above raw vectors are hidden from the ServerApp through adding masks.", ) log(INFO, "") log( @@ -63,8 +70,8 @@ def __call__(self, driver: Driver, context: Context) -> None: ndarrays = parameters_to_ndarrays(parameters) log( INFO, - "Weighted average of vectors (dequantized): %s", - ndarrays[0], + "Weighted average of parameters (dequantized): %s...", + ndarrays[0].flatten()[:3], ) log( INFO, @@ -88,5 +95,9 @@ def collect_masked_vectors_stage( ret = super().collect_masked_vectors_stage(driver, context, state) for node_id in state.sampled_node_ids - state.active_node_ids: log(INFO, "Client %s dropped out.", self.node_ids.index(node_id)) - log(INFO, "Obtained sum of masked vectors: %s", state.aggregate_ndarrays[1]) + log( + INFO, + "Obtained sum of masked parameters: %s...", + state.aggregate_ndarrays[1].flatten()[:3], + ) return ret