Add other flagship instruction dataset builders #541

Merged · 3 commits · Mar 22, 2024
2 changes: 2 additions & 0 deletions docs/source/api_ref_datasets.rst
@@ -11,4 +11,6 @@ torchtune.datasets
:nosignatures:

alpaca_dataset
grammar_dataset
samsum_dataset
SlimOrcaDataset
77 changes: 77 additions & 0 deletions tests/torchtune/datasets/test_grammar_dataset.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets._grammar import grammar_dataset
from torchtune.modules.tokenizer import Tokenizer


class TestGrammarDataset:
@pytest.fixture
def tokenizer(self):
# m.model is a SentencePiece model pretrained with the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return Tokenizer.from_file(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets._instruct.load_dataset")
def test_label_no_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is not masked.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
"output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
}
]

grammar_ds = grammar_dataset(tokenizer=tokenizer)
input, labels = grammar_ds[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert CROSS_ENTROPY_IGNORE_IDX not in labels

@patch("torchtune.datasets._instruct.load_dataset")
def test_label_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is masked.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
"output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
}
]

grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=False)

# Extract the prompt and tokenize it; we'll need this to test whether we're masking the
# input correctly
sample = grammar_ds._data[0]
prompt = grammar_ds.template.format(
sample=sample, column_map={"sentence": "input"}
)
encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)

# Generate the input and labels
input, labels = grammar_ds[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)
77 changes: 77 additions & 0 deletions tests/torchtune/datasets/test_samsum_dataset.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets._samsum import samsum_dataset
from torchtune.modules.tokenizer import Tokenizer


class TestSamsumDataset:
@pytest.fixture
def tokenizer(self):
# m.model is a SentencePiece model pretrained with the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return Tokenizer.from_file(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets._instruct.load_dataset")
def test_label_no_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is not masked.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"id": "13818513",
"dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)",
"summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
},
]

samsum_ds = samsum_dataset(tokenizer=tokenizer)
input, labels = samsum_ds[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert CROSS_ENTROPY_IGNORE_IDX not in labels

@patch("torchtune.datasets._instruct.load_dataset")
def test_label_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is masked.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"id": "13818513",
"dialogue": "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)",
"summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
},
]

samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=False)

# Extract the prompt and tokenize it; we'll need this to test whether we're masking the
# input correctly
sample = samsum_ds._data[0]
prompt = samsum_ds.template.format(sample=sample)
encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)

# Generate the input and labels
input, labels = samsum_ds[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)
3 changes: 2 additions & 1 deletion torchtune/data/_templates.py
@@ -27,7 +27,8 @@ def format(
sample (Mapping[str, Any]): a single data sample with various fields
column_map (Optional[Dict[str, str]]): a mapping from the expected
placeholder names in the template to the column names in the sample.
If None, assume these are identical.
If None, assume these are identical. Note: if the response column in the dataset
is not named "output", it must always be mapped to "output" in column_map.

Returns:
The formatted prompt
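
For illustration (not part of this diff): a minimal sketch of what the note above means in practice. The model path is an assumption; the template, builder, and column names come from this PR.

from torchtune.data import SummarizeTemplate
from torchtune.datasets import InstructDataset
from torchtune.modules import Tokenizer

tokenizer = Tokenizer.from_file("/path/to/m.model")  # assumed local SentencePiece model

# The samsum response column is named "summary", not "output", so it is
# mapped to "output" in column_map; the template placeholders keep their names.
ds = InstructDataset(
    tokenizer=tokenizer,
    source="samsum",
    template=SummarizeTemplate(),
    column_map={"output": "summary"},
    split="train",
)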
4 changes: 4 additions & 0 deletions torchtune/datasets/__init__.py
@@ -5,11 +5,15 @@
# LICENSE file in the root directory of this source tree.

from torchtune.datasets._alpaca import alpaca_dataset
from torchtune.datasets._grammar import grammar_dataset
from torchtune.datasets._instruct import InstructDataset
from torchtune.datasets._samsum import samsum_dataset
from torchtune.datasets._slimorca import SlimOrcaDataset

__all__ = [
"alpaca_dataset",
"grammar_dataset",
"samsum_dataset",
"SlimOrcaDataset",
"InstructDataset",
]
55 changes: 55 additions & 0 deletions torchtune/datasets/_grammar.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.data import GrammarErrorCorrectionTemplate
from torchtune.datasets._instruct import InstructDataset
from torchtune.modules import Tokenizer


def grammar_dataset(
tokenizer: Tokenizer,
train_on_input: bool = True,
Contributor (review comment): same comment below

) -> InstructDataset:
"""
Support for the Grammar dataset and its variants from HuggingFace Datasets.
https://huggingface.co/datasets/liweili/c4_200m

Data input format: https://huggingface.co/datasets/liweili/c4_200m#description

The prompt template is adapted from the llama_recipes codebase:
https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py#L50

where `input` and `output` are fields from the dataset.

Masking of the prompt during training is controlled by the `train_on_input` flag, which is
set to `True` by default.
- If `train_on_input` is True, the prompt is used during training and contributes to the loss.
- If `train_on_input` is False, the prompt is masked out (tokens are replaced with -100).

Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.

Returns:
InstructDataset: dataset configured with Grammar source data and template


Example:
>>> grammar_ds = grammar_dataset(tokenizer=tokenizer)
>>> for batch in DataLoader(grammar_ds, batch_size=8):
>>> print(f"Batch size: {len(batch)}")
>>> Batch size: 8
"""

return InstructDataset(
tokenizer=tokenizer,
source="liweili/c4_200m",
template=GrammarErrorCorrectionTemplate(),
column_map={"sentence": "input"},
train_on_input=train_on_input,
split="train",
)
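
For illustration (not part of this diff): a minimal sketch of the masking behavior described in the docstring, using made-up token ids. CROSS_ENTROPY_IGNORE_IDX is -100, per the tests above.

from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX

encoded_prompt = [1, 42, 17, 8, 99]  # made-up token ids; 1 = bos
encoded_response = [23, 56, 2]       # made-up token ids; 2 = eos

# train_on_input=True: labels mirror the full sequence, so the prompt
# contributes to the loss.
labels_unmasked = encoded_prompt + encoded_response

# train_on_input=False: prompt positions are replaced with the ignore index,
# so cross-entropy skips them.
labels_masked = [CROSS_ENTROPY_IGNORE_IDX] * len(encoded_prompt) + encoded_response

assert labels_masked.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)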
6 changes: 5 additions & 1 deletion torchtune/datasets/_instruct.py
@@ -75,7 +75,11 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
transformed_sample = self._transform(sample) if self._transform else sample

prompt = self.template.format(transformed_sample, self._column_map)
key_output = self._column_map["output"] if self._column_map else "output"
key_output = (
self._column_map["output"]
if self._column_map and "output" in self._column_map
else "output"
)
prompt_with_response = prompt + sample[key_output]

encoded_prompt = self._tokenizer.encode(
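
For illustration (not part of this diff): the guarded lookup above lets a partial column_map (one without an "output" entry, like the grammar builder's) fall back to the default key instead of raising a KeyError. A standalone sketch of the same logic; resolve_output_key is a hypothetical helper, not part of the codebase.

from typing import Dict, Optional

def resolve_output_key(column_map: Optional[Dict[str, str]]) -> str:
    # Mirrors the fallback in InstructDataset._prepare_sample.
    if column_map and "output" in column_map:
        return column_map["output"]
    return "output"

assert resolve_output_key({"sentence": "input"}) == "output"   # grammar case
assert resolve_output_key({"output": "summary"}) == "summary"  # samsum case
assert resolve_output_key(None) == "output"                    # no column_map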
55 changes: 55 additions & 0 deletions torchtune/datasets/_samsum.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.data import SummarizeTemplate
from torchtune.datasets._instruct import InstructDataset
from torchtune.modules import Tokenizer


def samsum_dataset(
tokenizer: Tokenizer,
train_on_input: bool = True,
Contributor (review comment): I think by default we want to leave train_on_input on False, as I've seen most datasets do mask out the prompt when computing loss (except alpaca).

) -> InstructDataset:
"""
Support for the SAMSum summarization dataset and its variants from HuggingFace Datasets.
https://huggingface.co/datasets/samsum

Data input format: https://huggingface.co/datasets/samsum#data-fields

The prompt template is adapted from the llama_recipes codebase:
https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/samsum_dataset.py#L13

where `dialogue` and `summary` are fields from the dataset.

Masking of the prompt during training is controlled by the `train_on_input` flag, which is
set to `True` by default.
- If `train_on_input` is True, the prompt is used during training and contributes to the loss.
- If `train_on_input` is False, the prompt is masked out (tokens are replaced with -100).

Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.

Returns:
InstructDataset: dataset configured with Summarization source data and template


Example:
>>> samsum_ds = samsum_dataset(tokenizer=tokenizer)
>>> for batch in DataLoader(samsum_ds, batch_size=8):
>>> print(f"Batch size: {len(batch)}")
>>> Batch size: 8
"""

return InstructDataset(
tokenizer=tokenizer,
source="samsum",
template=SummarizeTemplate(),
column_map={"output": "summary"},
train_on_input=train_on_input,
split="train",
)
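
For illustration (not part of this diff): a usage sketch of the builder above. The model path is an assumption, and train_on_input=False follows the reviewer's suggestion rather than this commit's default.

from torchtune.datasets import samsum_dataset
from torchtune.modules import Tokenizer

tokenizer = Tokenizer.from_file("/path/to/m.model")  # assumed SentencePiece model path

samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=False)

input_ids, labels = samsum_ds[0]
assert len(input_ids) == len(labels)  # prompt positions in labels are -100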