Add other flagship instruction dataset builders #541
Merged
Docs (torchtune.datasets API reference):

 @@ -11,4 +11,6 @@ torchtune.datasets
     :nosignatures:

     alpaca_dataset
+    grammar_dataset
+    samsum_dataset
     SlimOrcaDataset
New file (tests for grammar_dataset):

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets._grammar import grammar_dataset
from torchtune.modules.tokenizer import Tokenizer


class TestGrammarDataset:
    @pytest.fixture
    def tokenizer(self):
        # m.model is a pretrained SentencePiece model trained with the following command:
        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
        return Tokenizer.from_file(str(get_assets_path() / "m.model"))

    @patch("torchtune.datasets._instruct.load_dataset")
    def test_label_no_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is not masked.
        """
        # Mock the call to HF datasets
        load_dataset.return_value = [
            {
                "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
                "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
            }
        ]

        grammar_ds = grammar_dataset(tokenizer=tokenizer)
        input, labels = grammar_ds[0]

        assert len(input) == len(labels)
        assert labels[-1] == tokenizer.eos_id
        assert input[0] == tokenizer.bos_id
        assert CROSS_ENTROPY_IGNORE_IDX not in labels

    @patch("torchtune.datasets._instruct.load_dataset")
    def test_label_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is masked.
        """
        # Mock the call to HF datasets
        load_dataset.return_value = [
            {
                "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
                "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
            }
        ]

        grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=False)

        # Extract the prompt and tokenize it; we'll need this to test whether we're masking the
        # input correctly
        sample = grammar_ds._data[0]
        prompt = grammar_ds.template.format(
            sample=sample, column_map={"sentence": "input"}
        )
        encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)

        # Generate the input and labels
        input, labels = grammar_ds[0]

        assert len(input) == len(labels)
        assert labels[-1] == tokenizer.eos_id
        assert input[0] == tokenizer.bos_id
        assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)
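For intuition, here is a rough sketch of the label construction these tests verify. It is an illustration only, not torchtune's implementation; IGNORE_IDX stands in for CROSS_ENTROPY_IGNORE_IDX, which the docstrings below identify as -100.

IGNORE_IDX = -100  # stand-in for CROSS_ENTROPY_IGNORE_IDX

def build_labels(prompt_tokens, response_tokens, train_on_input):
    # The model input is the concatenated prompt + response token ids.
    tokens = prompt_tokens + response_tokens
    if train_on_input:
        # Every token contributes to the loss.
        labels = list(tokens)
    else:
        # Mask the prompt: one IGNORE_IDX per encoded prompt token, which is
        # exactly what test_label_masking asserts via labels.count(...).
        labels = [IGNORE_IDX] * len(prompt_tokens) + list(response_tokens)
    return tokens, labels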
New file (tests for samsum_dataset):

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets._samsum import samsum_dataset
from torchtune.modules.tokenizer import Tokenizer


class TestSamsumDataset:
    @pytest.fixture
    def tokenizer(self):
        # m.model is a pretrained SentencePiece model trained with the following command:
        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
        return Tokenizer.from_file(str(get_assets_path() / "m.model"))

    @patch("torchtune.datasets._instruct.load_dataset")
    def test_label_no_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is not masked.
        """
        # Mock the call to HF datasets
        load_dataset.return_value = [
            {
                "id": "13818513",
                "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)",
                "summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
            },
        ]

        samsum_ds = samsum_dataset(tokenizer=tokenizer)
        input, labels = samsum_ds[0]

        assert len(input) == len(labels)
        assert labels[-1] == tokenizer.eos_id
        assert input[0] == tokenizer.bos_id
        assert CROSS_ENTROPY_IGNORE_IDX not in labels

    @patch("torchtune.datasets._instruct.load_dataset")
    def test_label_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is masked.
        """
        # Mock the call to HF datasets
        load_dataset.return_value = [
            {
                "id": "13818513",
                "dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)",
                "summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
            },
        ]

        samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=False)

        # Extract the prompt and tokenize it; we'll need this to test whether we're masking the
        # input correctly
        sample = samsum_ds._data[0]
        prompt = samsum_ds.template.format(sample=sample)
        encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)

        # Generate the input and labels
        input, labels = samsum_ds[0]

        assert len(input) == len(labels)
        assert labels[-1] == tokenizer.eos_id
        assert input[0] == tokenizer.bos_id
        assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)
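The grammar test above passes a column_map to template.format, while the samsum test does not. As a rough illustration of what that mapping does (the template text and helper below are invented for illustration, not torchtune's templates):

# Illustration only: a toy template whose placeholder name ("sentence") may
# differ from the dataset's column name ("input"); column_map bridges the two.
TEMPLATE = "Correct this to standard English: {sentence}\n---\nCorrected: "

def format_prompt(sample, column_map=None):
    column_map = column_map or {}
    # For each placeholder, look up the dataset column it maps to,
    # falling back to the placeholder name itself.
    values = {"sentence": sample[column_map.get("sentence", "sentence")]}
    return TEMPLATE.format(**values)

sample = {
    "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
    "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
}
# The dataset stores the text under "input", so we map "sentence" -> "input".
print(format_prompt(sample, column_map={"sentence": "input"}))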
New file (grammar_dataset builder):

@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.data import GrammarErrorCorrectionTemplate
from torchtune.datasets._instruct import InstructDataset
from torchtune.modules import Tokenizer


def grammar_dataset(
    tokenizer: Tokenizer,
    train_on_input: bool = True,
) -> InstructDataset:
    """
    Support for the Grammar dataset and its variants from Hugging Face Datasets.
    https://huggingface.co/datasets/liweili/c4_200m

    Data input format: https://huggingface.co/datasets/liweili/c4_200m#description

    The prompt template is adapted from the llama_recipes codebase:
    https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py#L50

    where `input` and `output` are fields from the dataset.

    Masking of the prompt during training is controlled by the `train_on_input` flag, which is
    set to `True` by default.
    - If `train_on_input` is True, the prompt is used during training and
      contributes to the loss.
    - If `train_on_input` is False, the prompt is masked out (tokens replaced with -100).

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement
            an `encode` and a `decode` method.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.

    Returns:
        InstructDataset: dataset configured with the Grammar source data and template

    Example:
        >>> grammar_ds = grammar_dataset(tokenizer=tokenizer)
        >>> for batch in DataLoader(grammar_ds, batch_size=8):
        ...     print(f"Batch size: {len(batch)}")
        Batch size: 8
    """
    return InstructDataset(
        tokenizer=tokenizer,
        source="liweili/c4_200m",
        template=GrammarErrorCorrectionTemplate(),
        column_map={"sentence": "input"},
        train_on_input=train_on_input,
        split="train",
    )
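A quick usage sketch for the builder above, based on how the tests exercise it; the tokenizer path is a placeholder, and the (tokens, labels) return shape is taken from the tests:

from torchtune.datasets._grammar import grammar_dataset
from torchtune.modules.tokenizer import Tokenizer

# Hypothetical path; substitute your own pretrained SentencePiece model.
tokenizer = Tokenizer.from_file("/path/to/m.model")

# Mask the prompt so only the corrected sentence contributes to the loss.
grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=False)

# Each item is a (tokens, labels) pair of equal length.
tokens, labels = grammar_ds[0]
assert len(tokens) == len(labels)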
New file (samsum_dataset builder):

@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.data import SummarizeTemplate
from torchtune.datasets._instruct import InstructDataset
from torchtune.modules import Tokenizer


def samsum_dataset(
    tokenizer: Tokenizer,
    train_on_input: bool = True,
) -> InstructDataset:
    """
    Support for the SAMSum summarization dataset and its variants from Hugging Face Datasets.
    https://huggingface.co/datasets/samsum

    Data input format: https://huggingface.co/datasets/samsum#data-fields

    The prompt template is adapted from the llama_recipes codebase:
    https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/samsum_dataset.py#L13

    where `dialogue` and `summary` are fields from the dataset.

    Masking of the prompt during training is controlled by the `train_on_input` flag, which is
    set to `True` by default.
    - If `train_on_input` is True, the prompt is used during training and
      contributes to the loss.
    - If `train_on_input` is False, the prompt is masked out (tokens replaced with -100).

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement
            an `encode` and a `decode` method.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.

    Returns:
        InstructDataset: dataset configured with the SAMSum source data and template

    Example:
        >>> samsum_ds = samsum_dataset(tokenizer=tokenizer)
        >>> for batch in DataLoader(samsum_ds, batch_size=8):
        ...     print(f"Batch size: {len(batch)}")
        Batch size: 8
    """
    return InstructDataset(
        tokenizer=tokenizer,
        source="samsum",
        template=SummarizeTemplate(),
        column_map={"output": "summary"},
        train_on_input=train_on_input,
        split="train",
    )

Review comment (on the `train_on_input: bool = True` default): I think by default we want to leave train_on_input as False, as I've seen most datasets mask out the prompt when computing loss (except alpaca).
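Since items are variable-length (tokens, labels) pairs, batching them through a DataLoader needs a padding collate. A minimal sketch under stated assumptions follows; the pad_id and ignore_idx values and the collate function itself are assumptions for illustration, not torchtune's API:

import torch

# Hypothetical right-padding collate for (tokens, labels) pairs.
def padded_collate(batch, pad_id=0, ignore_idx=-100):
    max_len = max(len(tokens) for tokens, _ in batch)
    input_ids, labels = [], []
    for tokens, labs in batch:
        input_ids.append(list(tokens) + [pad_id] * (max_len - len(tokens)))
        # Pad labels with ignore_idx so padding never contributes to the loss.
        labels.append(list(labs) + [ignore_idx] * (max_len - len(labs)))
    return torch.tensor(input_ids), torch.tensor(labels)

# Usage (assuming samsum_ds = samsum_dataset(tokenizer=tokenizer) from above):
# from torch.utils.data import DataLoader
# loader = DataLoader(samsum_ds, batch_size=8, collate_fn=padded_collate)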
Review comment: same comment below.