refactor(examples) Update vit-finetune example #3935

Merged
merged 11 commits into from
Aug 24, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -144,7 +144,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
- [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
- [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm-flowertune)
- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune)
- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/flowertune-vit)
- [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
- [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)
- Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow))
1 change: 1 addition & 0 deletions doc/source/conf.py
@@ -195,6 +195,7 @@ def find_test_modules(package_path):
"apiref-binaries": "ref-api-cli.html",
"fedbn-example-pytorch-from-centralized-to-federated": "example-fedbn-pytorch-from-centralized-to-federated.html",
"how-to-use-built-in-middleware-layers": "how-to-use-built-in-mods.html",
"vit-finetune": "flowertune-vit.html",
# Restructuring: tutorials
"tutorial/Flower-0-What-is-FL": "tutorial-series-what-is-federated-learning.html",
"tutorial/Flower-1-Intro-to-FL-PyTorch": "tutorial-series-get-started-with-flower-pytorch.html",
@@ -1,68 +1,78 @@
---
title: Federated finetuning of a ViT
tags: [finetuneing, vision, fds]
tags: [finetuning, vision, fds]
dataset: [Oxford Flower-102]
framework: [torch, torchvision]
---

# Federated finetuning of a ViT
# Federated Finetuning of a Vision Transformer with Flower

This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple we'll be finetuning it to [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) datasset, creating 20 partitions using [Flower Datasets](https://flower.ai/docs/datasets/). We'll be finetuning just the exit `head` of the ViT, this means that the training is not that costly and each client requires just ~1GB of VRAM (for a batch size of 32 images).
This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple, we'll finetune it on the [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) dataset, creating 20 partitions using [Flower Datasets](https://flower.ai/docs/datasets/). We'll finetune just the exit `head` of the ViT, which keeps training cheap: each client requires just ~1GB of VRAM (for a batch size of 32 images) if you choose to use a GPU.

## Running the example
## Set up the project

If you haven't cloned the Flower repository already you might want to clone code example and discard the rest. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
### Clone the project

Start by cloning the example project:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/vit-finetune . && rm -rf flower && cd vit-finetune
git clone --depth=1 https://github.com/adap/flower.git _tmp \
        && mv _tmp/examples/flowertune-vit . \
        && rm -rf _tmp \
        && cd flowertune-vit
```

This will create a new directory called `vit-finetune` containing the following files:
This will create a new directory called `flowertune-vit` with the following structure:

```shell
flowertune-vit
├── vitexample
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp
│   ├── server_app.py   # Defines your ServerApp
│   └── task.py         # Defines your model, training and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md
```
-- README.md <- Your're reading this right now
-- main.py <- Main file that launches the simulation
-- client.py <- Contains Flower client code and ClientApp
-- server.py <- Contains Flower server code and ServerApp
-- model.py <- Defines model and train/eval functions
-- dataset.py <- Downloads, partitions and processes dataset
-- pyproject.toml <- Example dependencies, installable using Poetry
-- requirements.txt <- Example dependencies, installable using pip
```

### Installing Dependencies

Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
### Install dependencies and project

#### Poetry
Install the dependencies defined in `pyproject.toml` as well as the `vitexample` package.

```shell
poetry install
poetry shell
```bash
pip install -e .
```

#### pip
## Run the project

With an activated environemnt, install the dependencies for this example:
You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are new to Flower, we recommend using the _simulation_ mode, as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.

```shell
pip install -r requirements.txt
### Run with the Simulation Engine

> \[!TIP\]
> This example runs faster when the `ClientApp`s have access to a GPU. If your system has one, you can make use of it by configuring the `backend.client-resources` component in `pyproject.toml`. If you want to try running the example with GPU right away, use the `local-simulation-gpu` federation as shown below.
```bash
# Run with the default federation (CPU only)
flwr run .
```

### Run with `start_simulation()`
You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

```bash
flwr run . --run-config num-server-rounds=5,batch-size=64
```
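
To illustrate how such overrides relate to the defaults in `[tool.flwr.app.config]`, here is a minimal sketch. The helper `apply_overrides` is hypothetical and is not Flower's own `--run-config` parser; it only assumes that each `key=value` pair replaces the default stored under the same key and keeps that default's type.

```python
# Hypothetical illustration; NOT Flower's actual --run-config parser.
# Defaults copied from [tool.flwr.app.config] in pyproject.toml.
DEFAULTS = {
    "num-server-rounds": 3,
    "batch-size": 32,
    "learning-rate": 0.01,
    "dataset-name": "nelorth/oxford-flowers",
    "num-classes": 102,
}


def apply_overrides(defaults: dict, overrides: str) -> dict:
    """Return a copy of `defaults` updated from a 'k1=v1,k2=v2' string."""
    config = dict(defaults)
    for pair in overrides.split(","):
        key, _, raw = pair.partition("=")
        key = key.strip()
        if key not in config:
            # Assumption of this sketch: unknown keys are rejected.
            raise KeyError(f"unknown config key: {key}")
        # Cast the raw string to the type of the default value.
        config[key] = type(config[key])(raw.strip())
    return config


run_config = apply_overrides(DEFAULTS, "num-server-rounds=5,batch-size=64")
print(run_config["num-server-rounds"], run_config["batch-size"])
```

Keys that are not overridden, such as `learning-rate`, keep their default values.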

Running the example is quite straightforward. You can control the number of rounds `--num-rounds` (which defaults to 20).
Run the project in the `local-simulation-gpu` federation, which gives CPU and GPU resources to each `ClientApp`. By default, at most 5 `ClientApp` instances will run in parallel on the available GPU. You can tweak the degree of parallelism by adjusting the settings of this federation in `pyproject.toml`.

```bash
python main.py
# Run with the `local-simulation-gpu` federation
flwr run . local-simulation-gpu
```
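
The link between the per-client resources and the degree of parallelism can be sketched as simple arithmetic. This is a simplified model, not the scheduling code of Flower or its simulation backend: the function `max_parallel_clients` and the machine sizes (16 CPU cores, 1 GPU) are assumptions made for illustration.

```python
# Simplified model of resource-bounded concurrency (illustrative only).
def max_parallel_clients(
    total_cpus: int, total_gpus: int, num_cpus: float, num_gpus: float
) -> int:
    """Upper bound on ClientApps that fit in the given resources at once."""
    by_cpu = int(total_cpus / num_cpus + 1e-9)
    if num_gpus == 0:
        return by_cpu  # CPU-only federation: only cores limit parallelism
    # The small epsilon guards against floating-point error in e.g. 1 / 0.2.
    by_gpu = int(total_gpus / num_gpus + 1e-9)
    return min(by_cpu, by_gpu)


# Hypothetical machine: 16 CPU cores and 1 GPU, with the values from the
# `local-simulation-gpu` federation (num-cpus = 2, num-gpus = 0.2).
parallel = max_parallel_clients(16, 1, num_cpus=2, num_gpus=0.2)
print(parallel)
```

Under this model, lowering `num-gpus` to 0.1 would allow up to 8 clients on the same machine, at which point the CPU cores become the limiting resource.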

![](_static/central_evaluation.png)

Running the example as-is on an RTX 3090Ti should take ~15s/round running 5 clients in parallel (plus the _global model_ during centralized evaluation stages) in a single GPU. Note that more clients could fit in VRAM, but since the GPU utilization is high (99%-100%) we are probably better off not doing that (at least in this case).

You can adjust the `client_resources` passed to `start_simulation()` so more/less clients run at the same time in the GPU. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation.

```bash
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07 Driver Version: 535.161.07 CUDA Version: 12.2 |
Expand Down Expand Up @@ -90,12 +100,7 @@ You can adjust the `client_resources` passed to `start_simulation()` so more/les
+---------------------------------------------------------------------------------------+
```

### Run with Flower Next (preview)
### Run with the Deployment Engine

```bash
flower-simulation \
--client-app=client:app \
--server-app=server:app \
--num-supernodes=20 \
--backend-config='{"client_resources": {"num_cpus":4, "num_gpus":0.25}}'
```
> \[!NOTE\]
> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker.
43 changes: 43 additions & 0 deletions examples/flowertune-vit/pyproject.toml
@@ -0,0 +1,43 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "vitexample"
version = "1.0.0"
description = "Federated Finetuning of a Vision Transformer with Flower"
license = "Apache-2.0"
dependencies = [
    "flwr-nightly[simulation]==1.11.0.dev20240823",
    "flwr-datasets[vision]>=0.3.0",
    "torch==2.2.1",
    "torchvision==0.17.1",
]

[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.flwr.app]
publisher = "flwrlabs"

[tool.flwr.app.components]
serverapp = "vitexample.server_app:app"
clientapp = "vitexample.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
batch-size = 32
learning-rate = 0.01
dataset-name = "nelorth/oxford-flowers"
num-classes = 102

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 10

[tool.flwr.federations.local-simulation-gpu]
options.num-supernodes = 10
options.backend.client-resources.num-cpus = 2 # each ClientApp is assumed to use 2 CPUs
options.backend.client-resources.num-gpus = 0.2 # at most 5 ClientApps will run on a given GPU
1 change: 1 addition & 0 deletions examples/flowertune-vit/vitexample/__init__.py
@@ -0,0 +1 @@
"""vitexample: A Flower / PyTorch app with Vision Transformers."""
62 changes: 62 additions & 0 deletions examples/flowertune-vit/vitexample/client_app.py
@@ -0,0 +1,62 @@
"""vitexample: A Flower / PyTorch app with Vision Transformers."""

import torch
from torch.utils.data import DataLoader

from flwr.common import Context
from flwr.client import NumPyClient, ClientApp


from vitexample.task import apply_train_transforms, get_dataset_partition
from vitexample.task import get_model, set_params, get_params, train


class FedViTClient(NumPyClient):
    def __init__(self, trainloader, learning_rate, num_classes):
        self.trainloader = trainloader
        self.learning_rate = learning_rate
        self.model = get_model(num_classes)

        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)  # send model to device

    def fit(self, parameters, config):
        set_params(self.model, parameters)

        # Set optimizer
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        # Train locally
        avg_train_loss = train(
            self.model, self.trainloader, optimizer, epochs=1, device=self.device
        )
        # Return locally-finetuned part of the model
        return (
            get_params(self.model),
            len(self.trainloader.dataset),
            {"train_loss": avg_train_loss},
        )


def client_fn(context: Context):
    """Return a FedViTClient."""

    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    dataset_name = context.run_config["dataset-name"]
    trainpartition = get_dataset_partition(num_partitions, partition_id, dataset_name)

    batch_size = context.run_config["batch-size"]
    lr = context.run_config["learning-rate"]
    num_classes = context.run_config["num-classes"]
    trainset = trainpartition.with_transform(apply_train_transforms)

    trainloader = DataLoader(
        trainset, batch_size=batch_size, num_workers=2, shuffle=True
    )

    return FedViTClient(trainloader, lr, num_classes).to_client()


app = ClientApp(client_fn=client_fn)
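
The `fit` method above returns the updated parameters together with `len(self.trainloader.dataset)`; a FedAvg-style strategy uses that count to weight each client's contribution. A minimal sketch of that weighted average follows, assuming flat lists of floats in place of NumPy arrays. The function name `weighted_average` is illustrative and is not part of Flower's API.

```python
from typing import List, Tuple


def weighted_average(results: List[Tuple[List[float], int]]) -> List[float]:
    """Average parameter vectors, weighting each by its number of examples."""
    total_examples = sum(num_examples for _, num_examples in results)
    dim = len(results[0][0])
    return [
        sum(params[i] * num_examples for params, num_examples in results)
        / total_examples
        for i in range(dim)
    ]


# Two hypothetical clients: (parameters, number of training examples).
results = [([1.0, 2.0], 100), ([3.0, 4.0], 300)]
aggregated = weighted_average(results)
print(aggregated)
```

The client with 300 examples pulls the aggregate three times as strongly as the client with 100, which is why `fit` reports the partition size alongside the parameters.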
77 changes: 77 additions & 0 deletions examples/flowertune-vit/vitexample/server_app.py
@@ -0,0 +1,77 @@
"""vitexample: A Flower / PyTorch app with Vision Transformers."""

from logging import INFO

import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader

from vitexample.task import apply_eval_transforms
from vitexample.task import get_model, set_params, test, get_params

from flwr.common import Context, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg


def get_evaluate_fn(
    centralized_testset: Dataset,
    num_classes: int,
):
    """Return an evaluation function for centralized evaluation."""

    def evaluate(server_round, parameters, config):
        """Use the entire Oxford Flowers-102 test set for evaluation."""

        # Determine device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Instantiate model and apply current global parameters
        model = get_model(num_classes)
        set_params(model, parameters)
        model.to(device)

        # Apply transform to dataset
        testset = centralized_testset.with_transform(apply_eval_transforms)

        testloader = DataLoader(testset, batch_size=128)
        # Run evaluation
        loss, accuracy = test(model, testloader, device=device)
        log(INFO, f"round: {server_round} -> acc: {accuracy:.4f}, loss: {loss: .4f}")

        return loss, {"accuracy": accuracy}

    return evaluate


def server_fn(context: Context):

    # Define test set for central evaluation
    dataset_name = context.run_config["dataset-name"]
    dataset = load_dataset(dataset_name)
    test_set = dataset["test"]

    # Set initial global model
    num_classes = context.run_config["num-classes"]
    ndarrays = get_params(get_model(num_classes))
    init_parameters = ndarrays_to_parameters(ndarrays)

    # Configure the strategy
    strategy = FedAvg(
        fraction_fit=0.5,  # Sample 50% of available clients
        fraction_evaluate=0.0,  # No federated evaluation
        evaluate_fn=get_evaluate_fn(
            test_set, num_classes
        ),  # Global evaluation function
        initial_parameters=init_parameters,
    )

    # Construct ServerConfig
    num_rounds = context.run_config["num-server-rounds"]
    config = ServerConfig(num_rounds=num_rounds)

    return ServerAppComponents(strategy=strategy, config=config)


app = ServerApp(server_fn=server_fn)
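
With `fraction_fit=0.5` and the 10 supernodes configured in `pyproject.toml`, about half of the clients take part in each round. The sketch below illustrates that arithmetic under the assumption that the strategy enforces a minimum of 2 clients per round. `num_fit_clients` and `sample_clients` are hypothetical helpers written for this illustration, not the library's implementation.

```python
import random
from typing import List, Optional


def num_fit_clients(
    num_available: int, fraction_fit: float, min_fit_clients: int = 2
) -> int:
    """Number of clients to train in a round (assumed minimum of 2)."""
    wanted = max(int(num_available * fraction_fit), min_fit_clients)
    return min(wanted, num_available)  # cannot sample more than exist


def sample_clients(
    client_ids: List[int], fraction_fit: float, seed: Optional[int] = None
) -> List[int]:
    """Pick a uniformly random subset of clients for one round."""
    rng = random.Random(seed)
    k = num_fit_clients(len(client_ids), fraction_fit)
    return sorted(rng.sample(client_ids, k))


# 10 supernodes, as in the `local-simulation` federation.
selected = sample_clients(list(range(10)), fraction_fit=0.5, seed=0)
print(len(selected))  # 5 clients are selected per round
```

Since `fraction_evaluate=0.0` disables federated evaluation, only the centralized `evaluate` function above reports accuracy after each round.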