Update FedProx baseline #2286

Merged: 28 commits, Sep 13, 2023
19 changes: 6 additions & 13 deletions baselines/fedprox/README.md
@@ -1,15 +1,15 @@
---
title: Federated Optimization in Heterogeneous Networks
url: https://arxiv.org/abs/1812.06127
labels: [image classification, cross-device, stragglers] # please add between 4 and 10 single-word (maybe two-words) labels (e.g. "system heterogeneity", "image classification", "asynchronous", "weight sharing", "cross-silo")
dataset: [mnist] # list of datasets you include in your baseline
labels: [image classification, cross-device, stragglers]
dataset: [mnist]
---

# FedProx: Federated Optimization in Heterogeneous Networks

> Note: If you use this baseline in your work, please remember to cite the original authors of the paper as well as the Flower paper.

**Paper:** https://arxiv.org/abs/1812.06127
**Paper:** [arxiv.org/abs/1812.06127](https://arxiv.org/abs/1812.06127)

**Authors:** Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar and Virginia Smith.

@@ -34,7 +34,7 @@ dataset: [mnist] # list of datasets you include in your baseline
* A logistic regression model used in the FedProx paper for MNIST (see `models/LogisticRegression`). This is the model used by default.
* A two-layer CNN network as used in the FedAvg paper (see `models/Net`)

**Dataset:** This baseline only includes the MNIST dataset. By default it will be partitioned into 1000 clients following a pathological split where each client has examples of two (out of ten) class labels. The number of examples in each client is derived by sampling from a powerlaw distribution. The settings are as follow:
**Dataset:** This baseline only includes the MNIST dataset. By default, it will be partitioned into 1000 clients following a pathological split where each client has examples of two (out of ten) class labels. The number of examples in each client is derived by sampling from a powerlaw distribution. The settings are as follows:

| Dataset | #classes | #partitions | partitioning method | partition settings |
| :------ | :---: | :---: | :---: | :---: |
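For readers skimming the diff, here is a toy sketch (an assumption-laden illustration, not the repo's code) of the pathological split described above: examples are sorted by label, and each of the 1000 clients receives shards drawn from roughly two of the ten classes.

```python
# Toy sketch of a pathological split: sort examples by label, then hand each
# client two class-contiguous shards. All sizes here are illustrative.
import numpy as np

num_clients = 1000
labels = np.random.randint(0, 10, size=60_000)   # stand-in for MNIST targets
order = np.argsort(labels)                       # indices sorted by class
shards = np.array_split(order, 2 * num_clients)  # two shards per client
rng = np.random.default_rng(42)
perm = rng.permutation(2 * num_clients)
client_indices = [
    np.concatenate((shards[perm[2 * i]], shards[perm[2 * i + 1]]))
    for i in range(num_clients)
]  # each client's examples come from ~2 of the 10 classes
```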
@@ -56,17 +56,10 @@ The following table shows the main hyperparameters for this baseline with their

## Environment Setup

To construct the Python environment follow these steps:
To construct the Python environment, simply run:

```bash
# install the base Poetry environment
poetry install

# activate the environment
poetry shell

# install PyTorch with GPU support. Please note this baseline is very lightweight so it can run fine on a CPU.
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
```

## Running the Experiments
@@ -96,7 +89,7 @@ python -m fedprox.main --config-name fedavg

## Expected results

With the following command we run both FedProx and FedAvg configurations while iterating through different values of `mu` and `stragglers_fraction`. We ran each experiment five times (this is achieved by artificially adding an extra element to the config but that it doesn't have an impact on the FL setting `'+repeat_num=range(5)'`)
With the following command, we run both FedProx and FedAvg configurations while iterating through different values of `mu` and `stragglers_fraction`. We ran each experiment five times (this is achieved by artificially adding an extra element to the config but it doesn't have an impact on the FL setting `'+repeat_num=range(5)'`)

```bash
python -m fedprox.main --multirun mu=0.0,2.0 stragglers_fraction=0.0,0.5,0.9 '+repeat_num=range(5)'
1 change: 1 addition & 0 deletions baselines/fedprox/fedprox/__init__.py
@@ -0,0 +1 @@
"""FedProx package."""
29 changes: 12 additions & 17 deletions baselines/fedprox/fedprox/client.py
@@ -12,10 +12,10 @@
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from fedprox.dataset import load_datasets
from fedprox.models import test, train


# pylint: disable=too-many-arguments
class FlowerClient(
fl.client.NumPyClient
): # pylint: disable=too-many-instance-attributes
@@ -40,19 +40,19 @@ def __init__(
self.straggler_schedule = straggler_schedule

def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Returns the parameters of the current net."""
"""Return the parameters of the current net."""
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

def set_parameters(self, parameters: NDArrays) -> None:
"""Changes the parameters of the model using the given ones."""
"""Change the parameters of the model using the given ones."""
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
self.net.load_state_dict(state_dict, strict=True)

def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Implements distributed fit function for a given client."""
"""Implement distributed fit function for a given client."""
self.set_parameters(parameters)

# At each round check if the client is a straggler,
@@ -88,15 +88,15 @@ def fit(
self.device,
epochs=num_epochs,
learning_rate=self.learning_rate,
proximal_mu=config["proximal_mu"],
proximal_mu=float(config["proximal_mu"]),
)

return self.get_parameters({}), len(self.trainloader), {"is_straggler": False}

def evaluate(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[float, int, Dict]:
"""Implements distributed evaluation for a given client."""
"""Implement distributed evaluation for a given client."""
self.set_parameters(parameters)
loss, accuracy = test(self.net, self.valloader, self.device)
return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
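The `proximal_mu` forwarded to `train(...)` above is the FedProx `mu`. As a minimal sketch (not the exact `train` in `fedprox/models.py`), the local objective adds a proximal term that penalizes drift from the global weights:

```python
# FedProx local loss: task loss + (mu / 2) * ||w - w_global||^2.
# `global_params` is assumed to be a snapshot of the parameters taken once,
# before local training begins.
import torch

def fedprox_loss(criterion, outputs, labels, net, global_params, mu):
    task_loss = criterion(outputs, labels)
    proximal_term = 0.0
    for local_p, global_p in zip(net.parameters(), global_params):
        proximal_term += torch.square(local_p - global_p).sum()
    return task_loss + (mu / 2.0) * proximal_term

# once, before local training:
#   global_params = [p.detach().clone() for p in net.parameters()]
# then per batch:
#   loss = fedprox_loss(criterion, net(images), targets, net, global_params, mu)
```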
@@ -111,10 +111,8 @@ def gen_client_fn(
learning_rate: float,
stragglers: float,
model: DictConfig,
) -> Tuple[
Callable[[str], FlowerClient], DataLoader
]: # pylint: disable=too-many-arguments
"""Generates the client function that creates the Flower Clients.
) -> Callable[[str], FlowerClient]: # pylint: disable=too-many-arguments
"""Generate the client function that creates the Flower Clients.

Parameters
----------
@@ -139,13 +137,11 @@ def gen_client_fn(

Returns
-------
Tuple[Callable[[str], FlowerClient], DataLoader]
A tuple containing the client function that creates Flower Clients and
the DataLoader that will be used for testing
Callable[[str], FlowerClient]
A client function that creates Flower Clients.
"""

# Defines a staggling schedule for each clients, i.e at which round will they
# be a straggler. This is done so at each round the proportion of staggling
# Defines a straggling schedule for each clients, i.e at which round will they
# be a straggler. This is done so at each round the proportion of straggling
# clients is respected
stragglers_mat = np.transpose(
np.random.choice(
@@ -155,7 +151,6 @@

def client_fn(cid: str) -> FlowerClient:
"""Create a Flower client representing a single organization."""

# Load model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = instantiate(model).to(device)
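Before moving on to the config files: the `np.random.choice(` call truncated above builds the straggler schedule. A hedged sketch of the idea, with argument values assumed since the diff folds them away:

```python
# One 0/1 flag per (client, round); 1 means "act as a straggler this round".
# Drawing with p=[1 - stragglers, stragglers] keeps the expected fraction of
# straggling clients per round at `stragglers`.
import numpy as np

num_clients, num_rounds, stragglers = 1000, 100, 0.5  # assumed values
stragglers_mat = np.transpose(
    np.random.choice(
        [0, 1], size=(num_rounds, num_clients), p=[1 - stragglers, stragglers]
    )
)  # shape (num_clients, num_rounds); row i is client i's schedule
```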
4 changes: 2 additions & 2 deletions baselines/fedprox/fedprox/conf/config.yaml
@@ -29,12 +29,12 @@ model:

strategy:
_target_: flwr.server.strategy.FedProx
fraction_fit: 0.00001 # because we want the number of clients to sample on each roudn to be solely defined by min_fit_clients
fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients
fraction_evaluate: 0.0
min_fit_clients: ${clients_per_round}
min_evaluate_clients: 0
min_available_clients: ${clients_per_round}
evaluate_metrics_aggregation_fn:
_target_: fedprox.strategy.weighted_average
_partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actuallly calls the function (in aggregate_evaluate())
_partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actually calls the function (in aggregate_evaluate())
proximal_mu: ${mu}
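The `_partial_: true` comment above is worth unpacking. A minimal sketch of the Hydra behavior it relies on, with the config inlined here for illustration:

```python
# With _partial_: true, instantiate() returns functools.partial(target)
# instead of calling the target, so the strategy can invoke it later in
# aggregate_evaluate() with the metrics it has collected.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"_target_": "fedprox.strategy.weighted_average", "_partial_": True}
)
weighted_average_fn = instantiate(cfg)  # a functools.partial, not a result
# later, inside the strategy:
#   aggregated_metrics = weighted_average_fn(eval_metrics)
```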
4 changes: 2 additions & 2 deletions baselines/fedprox/fedprox/conf/fedavg.yaml
@@ -34,11 +34,11 @@ model:

strategy:
_target_: fedprox.strategy.FedAvgWithStragglerDrop #! this points to FedAvgWithStragglerDrop class in strategy.py, Note that we need the full module path (including `fedprox`)
fraction_fit: 0.00001 # because we want the number of clients to sample on each roudn to be solely defined by min_fit_clients
fraction_fit: 0.00001 # because we want the number of clients to sample on each round to be solely defined by min_fit_clients
fraction_evaluate: 0.0
min_fit_clients: ${clients_per_round}
min_available_clients: ${clients_per_round}
min_evaluate_clients: 0
evaluate_metrics_aggregation_fn:
_target_: fedprox.strategy.weighted_average
_partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actuallly calls the function (in aggregate_evaluate())
_partial_: true # we dont' want this function to be evaluated when instantiating the strategy, we treat it as a partial and evaluate it when the strategy actually calls the function (in aggregate_evaluate())
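`FedAvgWithStragglerDrop` itself is defined in `fedprox/strategy.py`, which this diff does not show. A hedged sketch of what such a strategy plausibly does, given that clients report an `is_straggler` metric (see `client.py` above):

```python
# Sketch only: filter out results flagged as stragglers, then fall back to
# plain FedAvg aggregation. The real class may differ in its details.
from flwr.server.strategy import FedAvg

class FedAvgWithStragglerDrop(FedAvg):
    """Drop straggler results before aggregating (illustrative)."""

    def aggregate_fit(self, server_round, results, failures):
        kept = [
            (client, fit_res)
            for client, fit_res in results
            if not fit_res.metrics.get("is_straggler", False)
        ]
        return super().aggregate_fit(server_round, kept, failures)
```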
5 changes: 3 additions & 2 deletions baselines/fedprox/fedprox/dataset.py
@@ -17,7 +17,7 @@ def load_datasets( # pylint: disable=too-many-arguments
batch_size: Optional[int] = 32,
seed: Optional[int] = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
"""Creates the dataloaders to be fed into the model.
"""Create the dataloaders to be fed into the model.

Parameters
----------
@@ -36,7 +36,8 @@
Returns
-------
Tuple[DataLoader, DataLoader, DataLoader]
The DataLoader for training, the DataLoader for validation, the DataLoader for testing.
The DataLoader for training, the DataLoader for validation, the DataLoader
for testing.
"""
print(f"Dataset partitioning config: {config}")
datasets, testset = _partition_data(
36 changes: 22 additions & 14 deletions baselines/fedprox/fedprox/dataset_preparation.py
@@ -1,3 +1,4 @@
"""Functions for dataset download and processing."""
from typing import List, Optional, Tuple

import numpy as np
@@ -8,7 +9,7 @@


def _download_data() -> Tuple[Dataset, Dataset]:
"""Downloads (if necessary) and returns the MNIST dataset.
"""Download (if necessary) and returns the MNIST dataset.

Returns
-------
@@ -23,24 +24,27 @@ def _download_data() -> Tuple[Dataset, Dataset]:
return trainset, testset


# pylint: disable=too-many-locals
def _partition_data(
num_clients,
iid: Optional[bool] = False,
power_law: Optional[bool] = True,
balance: Optional[bool] = False,
seed: Optional[int] = 42,
) -> Tuple[List[Dataset], Dataset]:
"""Split training set into iid or non iid partitions to simulate the
federated setting.
"""Split training set into iid or non iid partitions to simulate the federated.

setting.

Parameters
----------
num_clients : int
The number of clients that hold a part of the data
iid : bool, optional
Whether the data should be independent and identically distributed between
the clients or if the data should first be sorted by labels and distributed by chunks
to each client (used to test the convergence in a worst case scenario), by default False
the clients or if the data should first be sorted by labels and distributed
by chunks to each client (used to test the convergence in a worst case scenario)
, by default False
power_law: bool, optional
Whether to follow a power-law distribution when assigning number of samples
for each client, defaults to True
@@ -53,13 +57,14 @@ def _partition_data(
Returns
-------
Tuple[List[Dataset], Dataset]
A list of dataset for each client and a single dataset to be use for testing the model.
A list of dataset for each client and a single dataset to be use for testing
the model.
"""
trainset, testset = _download_data()

if balance:
trainset = _balance_classes(trainset, seed)

partition_size = int(len(trainset) / num_clients)
lengths = [partition_size] * num_clients

@@ -171,6 +176,7 @@ def _sort_by_class(
return sorted_dataset


# pylint: disable=too-many-locals, too-many-arguments
def _power_law_split(
sorted_trainset: Dataset,
num_partitions: int,
@@ -179,9 +185,10 @@
mean: float = 0.0,
sigma: float = 2.0,
) -> Dataset:
"""Partitions the dataset following a power-law distribution. It follows
the implementation of Li et al 2020: https://arxiv.org/abs/1812.06127 with
default values set accordingly.
"""Partition the dataset following a power-law distribution. It follows the.

implementation of Li et al 2020: https://arxiv.org/abs/1812.06127 with default
values set accordingly.

Parameters
----------
@@ -205,15 +212,14 @@
Dataset
The partitioned training dataset.
"""

targets = sorted_trainset.targets
full_idx = range(len(targets))
full_idx = list(range(len(targets)))

class_counts = np.bincount(sorted_trainset.targets)
labels_cs = np.cumsum(class_counts)
labels_cs = [0] + labels_cs[:-1].tolist()

partitions_idx = []
partitions_idx: List[List[int]] = []
num_classes = len(np.bincount(targets))
hist = np.zeros(num_classes, dtype=np.int32)

@@ -243,7 +249,9 @@ def _power_law_split(
(num_classes, int(num_partitions / num_classes), num_labels_per_partition),
)
remaining_per_class = class_counts - hist
# obtain how many samples each partition should be assigned for each of the labels it contains
# obtain how many samples each partition should be assigned for each of the
# labels it contains
# pylint: disable=too-many-function-args
probs = (
remaining_per_class.reshape(-1, 1, 1)
* probs
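As a toy illustration of the power-law step above (all numbers are made up, not the repo's): lognormal draws with the paper's defaults `mean=0.0, sigma=2.0` skew the per-partition sample counts heavily.

```python
# Toy sketch: spread 5000 leftover samples over 10 partitions using lognormal
# weights with the paper's default parameters (mean=0.0, sigma=2.0).
import numpy as np

rng = np.random.default_rng(42)
weights = rng.lognormal(mean=0.0, sigma=2.0, size=10)
counts = (weights / weights.sum() * 5_000).astype(int)
print(counts)  # heavily skewed: a few partitions dominate, most get little
```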
11 changes: 8 additions & 3 deletions baselines/fedprox/fedprox/main.py
@@ -1,5 +1,7 @@
"""Runs CNN federated learning for MNIST dataset."""

from typing import Dict, Union

import flwr as fl
import hydra
from hydra.core.hydra_config import HydraConfig
@@ -10,17 +12,18 @@
from fedprox.dataset import load_datasets
from fedprox.utils import save_results_as_pickle

FitConfig = Dict[str, Union[bool, float]]


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
"""Main function to run CNN federated learning on MNIST.
"""Run CNN federated learning on MNIST.

Parameters
----------
cfg : DictConfig
An omegaconf object that stores the hydra config.
"""

# print config structured as YAML
print(OmegaConf.to_yaml(cfg))

@@ -53,7 +56,9 @@ def main(cfg: DictConfig) -> None:
def get_on_fit_config():
def fit_config_fn(server_round: int):
# resolve and convert to python dict
fit_config = OmegaConf.to_container(cfg.fit_config, resolve=True)
fit_config: FitConfig = OmegaConf.to_container( # type: ignore
cfg.fit_config, resolve=True
)
fit_config["curr_round"] = server_round # add round info
return fit_config

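For context on the `OmegaConf.to_container(..., resolve=True)` call above, a small self-contained example (config inlined, keys assumed):

```python
# resolve=True turns interpolations like ${mu} into concrete Python values,
# so the dict shipped to clients contains plain numbers, not OmegaConf nodes.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"mu": 2.0, "fit_config": {"proximal_mu": "${mu}"}})
fit_config = OmegaConf.to_container(cfg.fit_config, resolve=True)
print(fit_config)  # {'proximal_mu': 2.0}
```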
6 changes: 3 additions & 3 deletions baselines/fedprox/fedprox/models.py
@@ -1,4 +1,4 @@
"""CNN model architecutre, training, and testing functions for MNIST."""
"""CNN model architecture, training, and testing functions for MNIST."""


from typing import List, Tuple
@@ -61,7 +61,7 @@ class LogisticRegression(nn.Module):

def __init__(self, num_classes: int) -> None:
super().__init__()
self.fc = nn.Linear(28 * 28, num_classes)
self.linear = nn.Linear(28 * 28, num_classes)

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""Forward pass.
@@ -76,7 +76,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
torch.Tensor
The resulting Tensor after it has passed through the network
"""
output_tensor = self.fc(torch.flatten(input_tensor, 1))
output_tensor = self.linear(torch.flatten(input_tensor, 1))
return output_tensor


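A quick hedged usage check for the renamed layer above (batch size and shapes assumed):

```python
# 28x28 MNIST-style input: flatten(..., 1) gives (N, 784); the linear layer
# maps that to one logit per class.
import torch

model = LogisticRegression(num_classes=10)  # class shown in the diff above
logits = model(torch.randn(8, 1, 28, 28))
print(logits.shape)  # torch.Size([8, 10])
```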