Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add FlowerTune templates to flwr new #3587

Merged
merged 51 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
41e231b
Init flwrtune to flwr new
yan-gao-GY Jun 11, 2024
39a448f
Update flwr run for backend_config passing
yan-gao-GY Jun 11, 2024
ab19093
Update
yan-gao-GY Jun 11, 2024
c828db8
Init flwr new with 4 LLM tasks
yan-gao-GY Jun 12, 2024
cd7a896
Formatting
yan-gao-GY Jun 12, 2024
df70a0f
Update src/py/flwr/cli/new/templates/app/pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 12, 2024
7787d51
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 12, 2024
70e5d7a
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 12, 2024
4b6720b
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 12, 2024
421ddce
Fix config files
yan-gao-GY Jun 12, 2024
922c0c5
Update pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 12, 2024
8e60056
Update pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 12, 2024
51ba186
Fix
yan-gao-GY Jun 13, 2024
a318ac4
Fix
yan-gao-GY Jun 13, 2024
ca33364
Fix
yan-gao-GY Jun 13, 2024
5817d47
Avoid warnings
yan-gao-GY Jun 13, 2024
9b04e87
Fix
yan-gao-GY Jun 13, 2024
b9ab673
Update readme and formatting
yan-gao-GY Jun 13, 2024
f4cc70a
Formatting
yan-gao-GY Jun 13, 2024
2fc3c58
Update src/py/flwr/cli/new/templates/app/pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 13, 2024
56ebd7a
Formatting
yan-gao-GY Jun 13, 2024
8bbdff3
Update src/py/flwr/cli/new/templates/app/code/flwrtune/config.yaml.tpl
jafermarq Jun 14, 2024
fdc867a
Merge branch 'main' into add-flwrtune-flwrnew
jafermarq Jun 14, 2024
83b30e6
Merge branch 'main' into add-flwrtune-flwrnew
jafermarq Jun 18, 2024
27cae47
Update readme
yan-gao-GY Jun 18, 2024
3bcf6dd
Replace `task` with `challenge`
yan-gao-GY Jun 18, 2024
9fd9a3f
Formatting
yan-gao-GY Jun 18, 2024
95e65dd
Update readme
yan-gao-GY Jun 18, 2024
29d78cd
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
5648a79
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
bfe3e1a
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
fabe74a
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
2a9ea73
Update readme
yan-gao-GY Jun 18, 2024
b634864
Merge branch 'main' into add-flwrtune-flwrnew
yan-gao-GY Jun 18, 2024
28c5602
Update readme
yan-gao-GY Jun 18, 2024
c0085ed
Update src/py/flwr/cli/new/new.py
yan-gao-GY Jun 20, 2024
5056153
Update src/py/flwr/cli/new/new.py
yan-gao-GY Jun 20, 2024
8ccf93f
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
cd6d02e
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
4cc871a
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
69a6693
Update src/py/flwr/cli/new/templates/app/code/flwrtune/client.py.tpl
yan-gao-GY Jun 20, 2024
dc9e08e
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
6368699
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
6ec4288
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
bf07c4f
Formatting
yan-gao-GY Jun 20, 2024
76093e3
Change model parameter init method & formatting
yan-gao-GY Jun 20, 2024
061275b
Merge branch 'main' into add-flwrtune-flwrnew
yan-gao-GY Jun 20, 2024
dc34b16
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 21, 2024
d48c2a2
Update FlowerTune names
yan-gao-GY Jun 21, 2024
4a61686
Formatting
yan-gao-GY Jun 21, 2024
e4f66ed
Merge branch 'main' into add-flwrtune-flwrnew
yan-gao-GY Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 95 additions & 27 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ class MlFramework(str, Enum):
HUGGINGFACE = "HF"
MLX = "MLX"
SKLEARN = "sklearn"
FLWRTUNE = "flwrtune"


class LLMTaskName(str, Enum):
"""Available LLM tasks."""

GENERALNLP = "GeneralNLP"
FINANCE = "Finance"
MEDICAL = "Medical"
CODE = "Code"


class TemplateNotFound(Exception):
Expand Down Expand Up @@ -81,6 +91,7 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
create_file(file_path, content)


# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def new(
project_name: Annotated[
Optional[str],
Expand Down Expand Up @@ -125,6 +136,17 @@ def new(

framework_str = framework_str.lower()

if framework_str == "flwrtune":
llm_task_value = prompt_options(
"Please select LLM task by typing in the number",
sorted([task.value for task in LLMTaskName]),
)
selected_value = [
name for name, value in vars(LLMTaskName).items() if value == llm_task_value
]
llm_task_str = selected_value[0]
llm_task_str = llm_task_str.lower()

print(
typer.style(
f"\n🔨 Creating Flower project {project_name}...",
Expand All @@ -139,40 +161,86 @@ def new(
import_name = package_name.replace("-", "_")
project_dir = os.path.join(cwd, package_name)

# List of files to render
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
},
f"{import_name}/client.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
}

context = {
"project_name": project_name,
"package_name": package_name,
"import_name": import_name.replace("-", "_"),
"username": username,
}

# List of files to render
if framework_str == "flwrtune":
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
"README.md": {"template": f"app/README.{framework_str}.md.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {"template": "app/code/flwrtune/server.py.tpl"},
f"{import_name}/client.py": {"template": "app/code/flwrtune/client.py.tpl"},
f"{import_name}/app.py": {"template": "app/code/flwrtune/app.py.tpl"},
f"{import_name}/models.py": {"template": "app/code/flwrtune/models.py.tpl"},
f"{import_name}/dataset.py": {
"template": "app/code/flwrtune/dataset.py.tpl"
},
f"{import_name}/conf/config.yaml": {
"template": "app/code/flwrtune/config.yaml.tpl"
},
f"{import_name}/conf/static_config.yaml": {
"template": "app/code/flwrtune/static_config.yaml.tpl"
},
}

# Task specific context
fraction_fit = "0.2" if llm_task_str == "code" else "0.1"
if llm_task_str == "generalnlp":
task_name = "General NLP"
num_clients = "20"
dataset_name = "vicgalle/alpaca-gpt4"
elif llm_task_str == "finance":
task_name = "Finance"
num_clients = "50"
dataset_name = "FinGPT/fingpt-sentiment-train"
elif llm_task_str == "medical":
task_name = "Medical"
num_clients = "20"
dataset_name = "medalpaca/medical_meadow_medical_flashcards"
else:
task_name = "Code"
num_clients = "10"
dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"

context["llm_task_str"] = llm_task_str
context["fraction_fit"] = fraction_fit
context["task_name"] = task_name
context["num_clients"] = num_clients
context["dataset_name"] = dataset_name
else:
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
},
f"{import_name}/client.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
}

for file_path, value in files.items():
render_and_create(
file_path=os.path.join(project_dir, file_path),
Expand Down
41 changes: 41 additions & 0 deletions src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# FlowerTune-LLM on $task_name Dataset

This directory conducts federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a $task_name dataset.
We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset.
Flower's Simulation Engine is used to simulate the LLM fine-tuning process in federated way,
which allows users to perform the training on a single GPU.

## Methodology
This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
The clients' models are aggregated with FedAvg strategy.
This provides a baseline performance for the leaderboard of General NLP task.


## Environments setup
Project dependencies are defined in `pyproject.toml`. Install them with:

```shell
pip install -e .
```

## Experimental setup
The dataset is partitioned into $num_clients shards with IID fashion serving as clients.
We randomly sample $fraction_fit clients to be available for each round,
and the federated fine-tuning lasts for `200` rounds.
All settings are defined in `$project_name/conf/static_config.yaml`, which is not allowed to be modified for fair competition.
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved


## Running the task
First make sure that you have got the access to [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model with your Hugging-Face account.
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved
Then log in with your [User Access Token](https://huggingface.co/docs/hub/security-tokens).

```bash
huggingface-cli login --token XXXXXXX
```

With an activated Python environment, run the task with default config values.
The configs are in `$project_name/conf/config.yaml` and `$project_name/conf/static_config.yaml`, and are loaded automatically.

```bash
flwr run
```
15 changes: 15 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/flwrtune/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower CLI `new` command app / code / flwrtune templates."""
77 changes: 77 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""$project_name: A Flower / flwrtune app."""

import os
import warnings
from datetime import datetime
from hydra import compose, initialize
from hydra.utils import instantiate
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting

import flwr as fl
from flwr_datasets import FederatedDataset

from $import_name.client import gen_client_fn
from $import_name.server import get_on_fit_config, fit_weighted_average, get_evaluate_fn


# Avoid warnings
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"

# Initialise regular config
with initialize(config_path="conf", version_base="1.1"):
cfg = compose(config_name="config")

# Initialise static config
with initialize(config_path="conf", version_base="1.1"):
cfg_static = compose(config_name="static_config")

cfg.train.num_rounds = cfg_static.num_rounds

# Create output directory given current timestamp
current_time = datetime.now()
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
os.makedirs(save_path, exist_ok=True)

# Partition dataset and get dataloaders
partitioner = instantiate(cfg_static.partitioner)
fds = FederatedDataset(
dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
)
(
tokenizer,
data_collator,
formatting_prompts_func,
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)

# ClientApp for Flower-Next
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved
client = fl.client.ClientApp(
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved
client_fn=gen_client_fn(
fds,
tokenizer,
formatting_prompts_func,
data_collator,
cfg.model,
cfg.train,
save_path,
),
)

# Instantiate strategy according to config. Here we pass other arguments
# that are only defined at run time.
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved
strategy = instantiate(
cfg.strategy,
on_fit_config_fn=get_on_fit_config(),
fit_metrics_aggregation_fn=fit_weighted_average,
evaluate_fn=get_evaluate_fn(
cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
),
)

# ServerApp for Flower-Next
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved
server = fl.server.ServerApp(
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved
config=fl.server.ServerConfig(num_rounds=cfg_static.num_rounds),
yan-gao-GY marked this conversation as resolved.
Show resolved Hide resolved
strategy=strategy,
)
Loading