refactor(examples) Update vit-finetune example (#3935)
jafermarq authored Aug 24, 2024
1 parent ecac7f5 commit 75ea504
Showing 16 changed files with 362 additions and 384 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -144,7 +144,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
- [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
- [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm-flowertune)
-- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/vit-finetune)
+- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/flowertune-vit)
- [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
- [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)
- Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation-tensorflow))
1 change: 1 addition & 0 deletions doc/source/conf.py
@@ -195,6 +195,7 @@ def find_test_modules(package_path):
"apiref-binaries": "ref-api-cli.html",
"fedbn-example-pytorch-from-centralized-to-federated": "example-fedbn-pytorch-from-centralized-to-federated.html",
"how-to-use-built-in-middleware-layers": "how-to-use-built-in-mods.html",
"vit-finetune": "flowertune-vit.html",
# Restructuring: tutorials
"tutorial/Flower-0-What-is-FL": "tutorial-series-what-is-federated-learning.html",
"tutorial/Flower-1-Intro-to-FL-PyTorch": "tutorial-series-get-started-with-flower-pytorch.html",
examples/{vit-finetune → flowertune-vit}/README.md
@@ -1,68 +1,78 @@
---
title: Federated finetuning of a ViT
-tags: [finetuneing, vision, fds]
+tags: [finetuning, vision, fds]
dataset: [Oxford Flower-102]
framework: [torch, torchvision]
---

-# Federated finetuning of a ViT
+# Federated Finetuning of a Vision Transformer with Flower

-This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple we'll be finetuning it to [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) datasset, creating 20 partitions using [Flower Datasets](https://flower.ai/docs/datasets/). We'll be finetuning just the exit `head` of the ViT, this means that the training is not that costly and each client requires just ~1GB of VRAM (for a batch size of 32 images).
+This example shows how to use Flower's Simulation Engine to federate the finetuning of a Vision Transformer ([ViT-Base-16](https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16)) that has been pretrained on ImageNet. To keep things simple, we'll finetune it on the [Oxford Flower-102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) dataset, creating 20 partitions with [Flower Datasets](https://flower.ai/docs/datasets/). We'll finetune only the exit `head` of the ViT, which means training is inexpensive and each client requires just ~1GB of VRAM (for a batch size of 32 images) if you choose to use a GPU.
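The `task.py` added by this commit is not rendered on this page, but finetuning only the head with torchvision can be sketched as follows. This is a sketch, assuming the standard `vit_b_16` API; the `num_classes` argument mirrors the `num-classes` value in `pyproject.toml`:

```python
import torch.nn as nn
from torchvision.models import ViT_B_16_Weights, vit_b_16


def get_model(num_classes: int) -> nn.Module:
    """Sketch: ImageNet-pretrained ViT-B/16 with only the head trainable."""
    model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

    # Freeze the pretrained backbone...
    for param in model.parameters():
        param.requires_grad = False

    # ...and swap in a fresh classification head (trainable by default),
    # so only 768 x 102 + 102 ≈ 78K parameters are updated during FL
    in_features = model.heads.head.in_features  # 768 for ViT-B/16
    model.heads.head = nn.Linear(in_features, num_classes)
    return model
```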

-## Running the example
+## Set up the project

-If you haven't cloned the Flower repository already you might want to clone code example and discard the rest. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
+### Clone the project

+Start by cloning the example project:

```shell
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/vit-finetune . && rm -rf flower && cd vit-finetune
+git clone --depth=1 https://github.com/adap/flower.git _tmp \
+  && mv _tmp/examples/flowertune-vit . \
+  && rm -rf _tmp \
+  && cd flowertune-vit
```

-This will create a new directory called `vit-finetune` containing the following files:
+This will create a new directory called `flowertune-vit` with the following structure:

```shell
+flowertune-vit
+├── vitexample
+│   ├── __init__.py
+│   ├── client_app.py   # Defines your ClientApp
+│   ├── server_app.py   # Defines your ServerApp
+│   └── task.py         # Defines your model, training and data loading
+├── pyproject.toml      # Project metadata like dependencies and configs
+└── README.md
--- README.md          <- Your're reading this right now
--- main.py            <- Main file that launches the simulation
--- client.py          <- Contains Flower client code and ClientApp
--- server.py          <- Contains Flower server code and ServerApp
--- model.py           <- Defines model and train/eval functions
--- dataset.py         <- Downloads, partitions and processes dataset
--- pyproject.toml     <- Example dependencies, installable using Poetry
--- requirements.txt   <- Example dependencies, installable using pip
```

-### Installing Dependencies

-Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+### Install dependencies and project

-#### Poetry
+Install the dependencies defined in `pyproject.toml` as well as the `vitexample` package.

-```shell
-poetry install
-poetry shell
+```bash
+pip install -e .
```

-#### pip
+## Run the project

-With an activated environemnt, install the dependencies for this example:
+You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.

-```shell
-pip install -r requirements.txt
+### Run with the Simulation Engine

+> \[!TIP\]
+> This example runs faster when the `ClientApp`s have access to a GPU. If your system has one, you can make use of it by configuring the `backend.client-resources` component in `pyproject.toml`. If you want to try running the example with GPU right away, use the `local-simulation-gpu` federation as shown below.

+```bash
+# Run with the default federation (CPU only)
+flwr run .
```

-### Run with `start_simulation()`
+You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

+```bash
+flwr run . --run-config num-server-rounds=5,batch-size=64
+```
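The keys accepted by `--run-config` are the ones declared under `[tool.flwr.app.config]` in the new `pyproject.toml` (shown further down in this commit). At runtime both apps read them from the `Context`, as the new `client_app.py` and `server_app.py` below do; a minimal sketch:

```python
from flwr.common import Context


def server_fn(context: Context):
    # Values come from [tool.flwr.app.config] in pyproject.toml,
    # with any --run-config overrides applied on top
    num_rounds = context.run_config["num-server-rounds"]  # 5 after the override above
    batch_size = context.run_config["batch-size"]  # 64 after the override above
    ...
```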

-Running the example is quite straightforward. You can control the number of rounds `--num-rounds` (which defaults to 20).
+Run the project in the `local-simulation-gpu` federation, which gives CPU and GPU resources to each `ClientApp`. By default, at most 5 `ClientApp`s will run in parallel on the available GPU. You can tweak the degree of parallelism by adjusting the settings of this federation in `pyproject.toml`.

```bash
-python main.py
+# Run with the `local-simulation-gpu` federation
+flwr run . local-simulation-gpu
```

![](_static/central_evaluation.png)

Running the example as-is on an RTX 3090Ti should take ~15s/round running 5 clients in parallel (plus the _global model_ during centralized evaluation stages) in a single GPU. Note that more clients could fit in VRAM, but since the GPU utilization is high (99%-100%) we are probably better off not doing that (at least in this case).

You can adjust the `client_resources` passed to `start_simulation()` so more/less clients run at the same time in the GPU. Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customise your simulation.

```bash
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07 Driver Version: 535.161.07 CUDA Version: 12.2 |
@@ -90,12 +100,7 @@
+---------------------------------------------------------------------------------------+
```

-### Run with Flower Next (preview)
+### Run with the Deployment Engine

-```bash
-flower-simulation \
-    --client-app=client:app \
-    --server-app=server:app \
-    --num-supernodes=20 \
-    --backend-config='{"client_resources": {"num_cpus":4, "num_gpus":0.25}}'
-```
+> \[!NOTE\]
+> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker.
43 changes: 43 additions & 0 deletions examples/flowertune-vit/pyproject.toml
@@ -0,0 +1,43 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "vitexample"
version = "1.0.0"
description = "Federated Finetuning of a Vision Transformer with Flower"
license = "Apache-2.0"
dependencies = [
"flwr-nightly[simulation]==1.11.0.dev20240823",
"flwr-datasets[vision]>=0.3.0",
"torch==2.2.1",
"torchvision==0.17.1",
]

[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.flwr.app]
publisher = "flwrlabs"

[tool.flwr.app.components]
serverapp = "vitexample.server_app:app"
clientapp = "vitexample.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
batch-size = 32
learning-rate = 0.01
dataset-name = "nelorth/oxford-flowers"
num-classes = 102

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 10

[tool.flwr.federations.local-simulation-gpu]
options.num-supernodes = 10
options.backend.client-resources.num-cpus = 2 # each ClientApp is assumed to use 2 CPUs
options.backend.client-resources.num-gpus = 0.2 # at most 5 ClientApps will run on a given GPU
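The fractional `num-gpus` value is a per-`ClientApp` reservation handled by the simulation backend; dividing the GPU budget by it gives the parallelism ceiling quoted in the comment above. A quick sanity check of that arithmetic, assuming a hypothetical 16-CPU, single-GPU machine:

```python
import math

total_cpus, total_gpus = 16, 1  # hypothetical machine
cpus_per_app, gpus_per_app = 2, 0.2  # from [tool.flwr.federations.local-simulation-gpu]

max_parallel_apps = min(
    total_cpus // cpus_per_app,  # 8 apps fit within the CPU budget
    math.floor(total_gpus / gpus_per_app),  # 5 apps fit within the GPU budget
)
print(max_parallel_apps)  # -> 5, matching the comment above
```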
1 change: 1 addition & 0 deletions examples/flowertune-vit/vitexample/__init__.py
@@ -0,0 +1 @@
"""vitexample: A Flower / PyTorch app with Vision Transformers."""
62 changes: 62 additions & 0 deletions examples/flowertune-vit/vitexample/client_app.py
@@ -0,0 +1,62 @@
"""vitexample: A Flower / PyTorch app with Vision Transformers."""

import torch
from torch.utils.data import DataLoader

from flwr.common import Context
from flwr.client import NumPyClient, ClientApp


from vitexample.task import apply_train_transforms, get_dataset_partition
from vitexample.task import get_model, set_params, get_params, train


class FedViTClient(NumPyClient):
    def __init__(self, trainloader, learning_rate, num_classes):
        self.trainloader = trainloader
        self.learning_rate = learning_rate
        self.model = get_model(num_classes)

        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)  # send model to device

    def fit(self, parameters, config):
        set_params(self.model, parameters)

        # Set optimizer
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        # Train locally
        avg_train_loss = train(
            self.model, self.trainloader, optimizer, epochs=1, device=self.device
        )
        # Return locally-finetuned part of the model
        return (
            get_params(self.model),
            len(self.trainloader.dataset),
            {"train_loss": avg_train_loss},
        )


def client_fn(context: Context):
    """Return a FedViTClient."""

    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    dataset_name = context.run_config["dataset-name"]
    trainpartition = get_dataset_partition(num_partitions, partition_id, dataset_name)

    batch_size = context.run_config["batch-size"]
    lr = context.run_config["learning-rate"]
    num_classes = context.run_config["num-classes"]
    trainset = trainpartition.with_transform(apply_train_transforms)

    trainloader = DataLoader(
        trainset, batch_size=batch_size, num_workers=2, shuffle=True
    )

    return FedViTClient(trainloader, lr, num_classes).to_client()


app = ClientApp(client_fn=client_fn)
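`vitexample/task.py` is part of this commit but not rendered on this page. The helpers imported above typically follow Flower's standard PyTorch patterns; a plausible sketch is below (the real file may exchange only the finetuned head rather than the full `state_dict`):

```python
from collections import OrderedDict

import torch
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

fds = None  # Cache the FederatedDataset across client_fn calls


def get_dataset_partition(num_partitions: int, partition_id: int, dataset_name: str):
    """Sketch: return this node's training partition of the dataset."""
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset=dataset_name, partitioners={"train": partitioner}
        )
    return fds.load_partition(partition_id)


def get_params(model):
    """Sketch: extract model parameters as a list of NumPy ndarrays."""
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def set_params(model, parameters):
    """Sketch: copy a list of NumPy ndarrays back into the model."""
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)
```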
77 changes: 77 additions & 0 deletions examples/flowertune-vit/vitexample/server_app.py
@@ -0,0 +1,77 @@
"""vitexample: A Flower / PyTorch app with Vision Transformers."""

from logging import INFO

import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader

from vitexample.task import apply_eval_transforms
from vitexample.task import get_model, set_params, test, get_params

from flwr.common import Context, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg


def get_evaluate_fn(
    centralized_testset: Dataset,
    num_classes: int,
):
    """Return an evaluation function for centralized evaluation."""

    def evaluate(server_round, parameters, config):
        """Use the entire Oxford Flowers-102 test set for evaluation."""

        # Determine device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Instantiate model and apply current global parameters
        model = get_model(num_classes)
        set_params(model, parameters)
        model.to(device)

        # Apply transform to dataset
        testset = centralized_testset.with_transform(apply_eval_transforms)

        testloader = DataLoader(testset, batch_size=128)
        # Run evaluation
        loss, accuracy = test(model, testloader, device=device)
        log(INFO, f"round: {server_round} -> acc: {accuracy:.4f}, loss: {loss:.4f}")

        return loss, {"accuracy": accuracy}

    return evaluate


def server_fn(context: Context):

    # Define the test set for central evaluation
    dataset_name = context.run_config["dataset-name"]
    dataset = load_dataset(dataset_name)
    test_set = dataset["test"]

    # Set initial global model
    num_classes = context.run_config["num-classes"]
    ndarrays = get_params(get_model(num_classes))
    init_parameters = ndarrays_to_parameters(ndarrays)

    # Configure the strategy
    strategy = FedAvg(
        fraction_fit=0.5,  # Sample 50% of available clients
        fraction_evaluate=0.0,  # No federated evaluation
        evaluate_fn=get_evaluate_fn(
            test_set, num_classes
        ),  # Global evaluation function
        initial_parameters=init_parameters,
    )

    # Construct ServerConfig
    num_rounds = context.run_config["num-server-rounds"]
    config = ServerConfig(num_rounds=num_rounds)

    return ServerAppComponents(strategy=strategy, config=config)


app = ServerApp(server_fn=server_fn)
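The `test` function called inside `evaluate` also lives in the unrendered `task.py`. Matching the call signature and `(loss, accuracy)` return value used above, a conventional PyTorch evaluation loop would look like this sketch; the `"image"`/`"label"` batch keys are an assumption about what `apply_eval_transforms` yields:

```python
import torch


def test(model, testloader, device):
    """Sketch: return (average loss, accuracy) over the test set."""
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in testloader:
            # Assumes apply_eval_transforms yields "image" tensors and "label" ids
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return loss_sum / total, correct / total
```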