From 0da066c88f98104fbb047e67d462b56154d7f027 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Wed, 20 Mar 2024 17:38:06 -0700
Subject: [PATCH 1/3] add other instruction dataset builders

---
 docs/source/api_ref_datasets.rst              |  2 +
 .../datasets/test_grammar_dataset.py          | 77 +++++++++++++++++++
 .../torchtune/datasets/test_samsum_dataset.py | 77 +++++++++++++++++++
 torchtune/data/_templates.py                  |  3 +-
 torchtune/datasets/__init__.py                |  4 +
 torchtune/datasets/_grammar.py                | 55 +++++++++++++
 torchtune/datasets/_instruct.py               |  6 +-
 torchtune/datasets/_samsum.py                 | 55 +++++++++++++
 8 files changed, 277 insertions(+), 2 deletions(-)
 create mode 100644 tests/torchtune/datasets/test_grammar_dataset.py
 create mode 100644 tests/torchtune/datasets/test_samsum_dataset.py
 create mode 100644 torchtune/datasets/_grammar.py
 create mode 100644 torchtune/datasets/_samsum.py

diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst
index 6435039eb..d6199623d 100644
--- a/docs/source/api_ref_datasets.rst
+++ b/docs/source/api_ref_datasets.rst
@@ -11,4 +11,6 @@ torchtune.datasets
     :nosignatures:
 
     alpaca_dataset
+    grammar_dataset
+    samsum_dataset
     SlimOrcaDataset

diff --git a/tests/torchtune/datasets/test_grammar_dataset.py b/tests/torchtune/datasets/test_grammar_dataset.py
new file mode 100644
index 000000000..fcc0e63f0
--- /dev/null
+++ b/tests/torchtune/datasets/test_grammar_dataset.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from unittest.mock import patch
+
+import pytest
+
+from tests.test_utils import get_assets_path
+from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX
+
+from torchtune.datasets._grammar import grammar_dataset
+from torchtune.modules.tokenizer import Tokenizer
+
+
+class TestGrammarDataset:
+    @pytest.fixture
+    def tokenizer(self):
+        # m.model is a pretrained SentencePiece model, trained using the following command:
+        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
+        return Tokenizer.from_file(str(get_assets_path() / "m.model"))
+
+    @patch("torchtune.datasets._instruct.load_dataset")
+    def test_label_no_masking(self, load_dataset, tokenizer):
+        """
+        Test whether the input and the labels are correctly created when the input is not masked.
+        """
+
+        # mock the call to HF datasets
+        load_dataset.return_value = [
+            {
+                "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
+                "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
+            }
+        ]
+
+        grammar_ds = grammar_dataset(tokenizer=tokenizer)
+        input, labels = grammar_ds[0]
+
+        assert len(input) == len(labels)
+        assert labels[-1] == tokenizer.eos_id
+        assert input[0] == tokenizer.bos_id
+        assert CROSS_ENTROPY_IGNORE_IDX not in labels
+
+    @patch("torchtune.datasets._instruct.load_dataset")
+    def test_label_masking(self, load_dataset, tokenizer):
+        """
+        Test whether the input and the labels are correctly created when the input is masked.
+        """
+
+        # mock the call to HF datasets
+        load_dataset.return_value = [
+            {
+                "input": "Bitcoin is for $7,094 this morning, which CoinDesk says.",
+                "output": "Bitcoin goes for $7,094 this morning, according to CoinDesk.",
+            }
+        ]
+
+        grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=False)
+
+        # Extract the prompt and tokenize it; we'll need this to test whether we're masking the
+        # input correctly
+        sample = grammar_ds._data[0]
+        prompt = grammar_ds.template.format(
+            sample=sample, column_map={"sentence": "input"}
+        )
+        encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)
+
+        # Generate the input and labels
+        input, labels = grammar_ds[0]
+
+        assert len(input) == len(labels)
+        assert labels[-1] == tokenizer.eos_id
+        assert input[0] == tokenizer.bos_id
+        assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)

diff --git a/tests/torchtune/datasets/test_samsum_dataset.py b/tests/torchtune/datasets/test_samsum_dataset.py
new file mode 100644
index 000000000..78ac57a69
--- /dev/null
+++ b/tests/torchtune/datasets/test_samsum_dataset.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from unittest.mock import patch
+
+import pytest
+
+from tests.test_utils import get_assets_path
+from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX
+
+from torchtune.datasets._samsum import samsum_dataset
+from torchtune.modules.tokenizer import Tokenizer
+
+
+class TestSamsumDataset:
+    @pytest.fixture
+    def tokenizer(self):
+        # m.model is a pretrained SentencePiece model, trained using the following command:
+        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
+        return Tokenizer.from_file(str(get_assets_path() / "m.model"))
+
+    @patch("torchtune.datasets._instruct.load_dataset")
+    def test_label_no_masking(self, load_dataset, tokenizer):
+        """
+        Test whether the input and the labels are correctly created when the input is not masked.
+        """
+
+        # mock the call to HF datasets
+        load_dataset.return_value = [
+            {
+                "id": "13818513",
+                "dialogue": "Amanda: I baked cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)",
+                "summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
+            },
+        ]
+
+        samsum_ds = samsum_dataset(tokenizer=tokenizer)
+        input, labels = samsum_ds[0]
+
+        assert len(input) == len(labels)
+        assert labels[-1] == tokenizer.eos_id
+        assert input[0] == tokenizer.bos_id
+        assert CROSS_ENTROPY_IGNORE_IDX not in labels
+
+    @patch("torchtune.datasets._instruct.load_dataset")
+    def test_label_masking(self, load_dataset, tokenizer):
+        """
+        Test whether the input and the labels are correctly created when the input is masked.
+        """
+
+        # mock the call to HF datasets
+        load_dataset.return_value = [
+            {
+                "id": "13818513",
+                "dialogue": "Amanda: I baked cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)",
+                "summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
+            },
+        ]
+
+        samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=False)
+
+        # Extract the prompt and tokenize it; we'll need this to test whether we're masking the
+        # input correctly
+        sample = samsum_ds._data[0]
+        prompt = samsum_ds.template.format(sample=sample)
+        encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)
+
+        # Generate the input and labels
+        input, labels = samsum_ds[0]
+
+        assert len(input) == len(labels)
+        assert labels[-1] == tokenizer.eos_id
+        assert input[0] == tokenizer.bos_id
+        assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)

diff --git a/torchtune/data/_templates.py b/torchtune/data/_templates.py
index bd25d2eca..783d1554d 100644
--- a/torchtune/data/_templates.py
+++ b/torchtune/data/_templates.py
@@ -27,7 +27,8 @@ def format(
             sample (Mapping[str, Any]): a single data sample with various fields
             column_map (Optional[Dict[str, str]]): a mapping from the expected
                 placeholder names in the template to the column names in the sample.
-                If None, assume these are identical.
+                If None, assume these are identical. Note: if the column holding the output
+                is not named "output" in the dataset, it must always be mapped to "output" in column_map.
 
         Returns:
             The formatted prompt

diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py
index 4fecc9c50..787eb558a 100644
--- a/torchtune/datasets/__init__.py
+++ b/torchtune/datasets/__init__.py
@@ -5,11 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
 from torchtune.datasets._alpaca import alpaca_dataset
+from torchtune.datasets._grammar import grammar_dataset
 from torchtune.datasets._instruct import InstructDataset
+from torchtune.datasets._samsum import samsum_dataset
 from torchtune.datasets._slimorca import SlimOrcaDataset
 
 __all__ = [
     "alpaca_dataset",
+    "grammar_dataset",
+    "samsum_dataset",
     "SlimOrcaDataset",
     "InstructDataset",
 ]

diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py
new file mode 100644
index 000000000..0d4bec2eb
--- /dev/null
+++ b/torchtune/datasets/_grammar.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtune.data import GrammarErrorCorrectionTemplate
+from torchtune.datasets._instruct import InstructDataset
+from torchtune.modules import Tokenizer
+
+
+def grammar_dataset(
+    tokenizer: Tokenizer,
+    train_on_input: bool = True,
+) -> InstructDataset:
+    """
+    Support for the Grammar dataset and its variants from HuggingFace Datasets.
+    https://huggingface.co/datasets/liweili/c4_200m
+
+    Data input format: https://huggingface.co/datasets/liweili/c4_200m#description
+
+    The prompt template is taken from the llama_recipes codebase:
+    https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py#L50
+
+    where `input` and `output` are fields from the dataset.
+
+    Masking of the prompt during training is controlled by the `train_on_input` flag, which is
+    set to `True` by default:
+    - If `train_on_input` is True, the prompt is used during training and
+      contributes to the loss.
+    - If `train_on_input` is False, the prompt is masked out (tokens replaced with -100).
+
+    Args:
+        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
+        train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
+
+    Returns:
+        InstructDataset: dataset configured with Grammar source data and template
+
+    Example:
+        >>> grammar_ds = grammar_dataset(tokenizer=tokenizer)
+        >>> for batch in DataLoader(grammar_ds, batch_size=8):
+        ...     print(f"Batch size: {len(batch)}")
+        Batch size: 8
+    """
+
+    return InstructDataset(
+        tokenizer=tokenizer,
+        source="liweili/c4_200m",
+        template=GrammarErrorCorrectionTemplate(),
+        column_map={"sentence": "input"},
+        train_on_input=train_on_input,
+        split="train",
+    )

diff --git a/torchtune/datasets/_instruct.py b/torchtune/datasets/_instruct.py
index 2d25c8223..880982b4b 100644
--- a/torchtune/datasets/_instruct.py
+++ b/torchtune/datasets/_instruct.py
@@ -75,7 +75,11 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
         transformed_sample = self._transform(sample) if self._transform else sample
         prompt = self.template.format(transformed_sample, self._column_map)
 
-        key_output = self._column_map["output"] if self._column_map else "output"
+        key_output = (
+            self._column_map["output"]
+            if self._column_map and "output" in self._column_map
+            else "output"
+        )
         prompt_with_response = prompt + sample[key_output]
 
         encoded_prompt = self._tokenizer.encode(

diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py
new file mode 100644
index 000000000..754eb1665
--- /dev/null
+++ b/torchtune/datasets/_samsum.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtune.data import SummarizeTemplate
+from torchtune.datasets._instruct import InstructDataset
+from torchtune.modules import Tokenizer
+
+
+def samsum_dataset(
+    tokenizer: Tokenizer,
+    train_on_input: bool = True,
+) -> InstructDataset:
+    """
+    Support for the SAMSum summarization dataset and its variants from HuggingFace Datasets.
+    https://huggingface.co/datasets/samsum
+
+    Data input format: https://huggingface.co/datasets/samsum#data-fields
+
+    The prompt template is taken from the llama_recipes codebase:
+    https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/samsum_dataset.py#L13
+
+    where `dialogue` and `summary` are fields from the dataset.
+
+    Masking of the prompt during training is controlled by the `train_on_input` flag, which is
+    set to `True` by default:
+    - If `train_on_input` is True, the prompt is used during training and
+      contributes to the loss.
+    - If `train_on_input` is False, the prompt is masked out (tokens replaced with -100).
+
+    Args:
+        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
+        train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
+
+    Returns:
+        InstructDataset: dataset configured with Summarization source data and template
+
+    Example:
+        >>> samsum_ds = samsum_dataset(tokenizer=tokenizer)
+        >>> for batch in DataLoader(samsum_ds, batch_size=8):
+        ...     print(f"Batch size: {len(batch)}")
+        Batch size: 8
+    """
+
+    return InstructDataset(
+        tokenizer=tokenizer,
+        source="samsum",
+        template=SummarizeTemplate(),
+        column_map={"output": "summary"},
+        train_on_input=train_on_input,
+        split="train",
+    )

From 562b2a81d5c2d899834abc650306584d3c3a4eae Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Thu, 21 Mar 2024 23:44:27 -0700
Subject: [PATCH 2/3] set train_on_input default to false

---
 torchtune/datasets/_grammar.py | 6 +++---
 torchtune/datasets/_samsum.py  | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/torchtune/datasets/_grammar.py b/torchtune/datasets/_grammar.py
index 0d4bec2eb..dfd3c2ba2 100644
--- a/torchtune/datasets/_grammar.py
+++ b/torchtune/datasets/_grammar.py
@@ -11,7 +11,7 @@
 def grammar_dataset(
     tokenizer: Tokenizer,
-    train_on_input: bool = True,
+    train_on_input: bool = False,
 ) -> InstructDataset:
     """
     Support for the Grammar dataset and its variants from HuggingFace Datasets.
     https://huggingface.co/datasets/liweili/c4_200m
@@ -25,14 +25,14 @@
     where `input` and `output` are fields from the dataset.
 
     Masking of the prompt during training is controlled by the `train_on_input` flag, which is
-    set to `True` by default:
+    set to `False` by default:
     - If `train_on_input` is True, the prompt is used during training and
       contributes to the loss.
     - If `train_on_input` is False, the prompt is masked out (tokens replaced with -100).
 
     Args:
         tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
-        train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
+        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
 
     Returns:
         InstructDataset: dataset configured with Grammar source data and template

diff --git a/torchtune/datasets/_samsum.py b/torchtune/datasets/_samsum.py
index 754eb1665..3dc7ea819 100644
--- a/torchtune/datasets/_samsum.py
+++ b/torchtune/datasets/_samsum.py
@@ -11,7 +11,7 @@
 def samsum_dataset(
     tokenizer: Tokenizer,
-    train_on_input: bool = True,
+    train_on_input: bool = False,
 ) -> InstructDataset:
     """
     Support for the SAMSum summarization dataset and its variants from HuggingFace Datasets.
     https://huggingface.co/datasets/samsum
@@ -25,14 +25,14 @@
     where `dialogue` and `summary` are fields from the dataset.
 
     Masking of the prompt during training is controlled by the `train_on_input` flag, which is
-    set to `True` by default:
+    set to `False` by default:
     - If `train_on_input` is True, the prompt is used during training and
      contributes to the loss.
     - If `train_on_input` is False, the prompt is masked out (tokens replaced with -100).
 
     Args:
         tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
-        train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
+        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
 
     Returns:
         InstructDataset: dataset configured with Summarization source data and template

From 9f033a8dc7599a75b2e9de1a73c43ba143dc6f53 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Fri, 22 Mar 2024 00:27:07 -0700
Subject: [PATCH 3/3] fix unit test

---
 tests/torchtune/datasets/test_grammar_dataset.py | 4 ++--
 tests/torchtune/datasets/test_samsum_dataset.py  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/torchtune/datasets/test_grammar_dataset.py b/tests/torchtune/datasets/test_grammar_dataset.py
index fcc0e63f0..9c9c0b8cc 100644
--- a/tests/torchtune/datasets/test_grammar_dataset.py
+++ b/tests/torchtune/datasets/test_grammar_dataset.py
@@ -36,7 +36,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
             }
         ]
 
-        grammar_ds = grammar_dataset(tokenizer=tokenizer)
+        grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=True)
         input, labels = grammar_ds[0]
 
         assert len(input) == len(labels)
@@ -58,7 +58,7 @@ def test_label_masking(self, load_dataset, tokenizer):
             }
         ]
 
-        grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=False)
+        grammar_ds = grammar_dataset(tokenizer=tokenizer)
 
         # Extract the prompt and tokenize it; we'll need this to test whether we're masking the
         # input correctly

diff --git a/tests/torchtune/datasets/test_samsum_dataset.py b/tests/torchtune/datasets/test_samsum_dataset.py
index 78ac57a69..71eea5993 100644
--- a/tests/torchtune/datasets/test_samsum_dataset.py
+++ b/tests/torchtune/datasets/test_samsum_dataset.py
@@ -37,7 +37,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
             },
         ]
 
-        samsum_ds = samsum_dataset(tokenizer=tokenizer)
+        samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=True)
         input, labels = samsum_ds[0]
 
         assert len(input) == len(labels)
@@ -60,7 +60,7 @@ def test_label_masking(self, load_dataset, tokenizer):
             },
         ]
 
-        samsum_ds = samsum_dataset(tokenizer=tokenizer, train_on_input=False)
+        samsum_ds = samsum_dataset(tokenizer=tokenizer)
 
         # Extract the prompt and tokenize it; we'll need this to test whether we're masking the
         # input correctly
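
---

For reference, a minimal usage sketch of the builders added in this series. This is not part of the patches: the tokenizer path is a placeholder, and pad_collate is a local helper written here for illustration, not a torchtune API. It shows how the (tokens, labels) pairs returned by InstructDataset batch up, with labels padded using the same -100 ignore index the builders use for prompt masking.

# Usage sketch (illustrative; pad_collate and the tokenizer path are assumptions).
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from torchtune.datasets import samsum_dataset
from torchtune.modules.tokenizer import Tokenizer

CROSS_ENTROPY_IGNORE_IDX = -100  # same value used by torchtune.datasets._common


def pad_collate(batch):
    # Each sample is a (tokens, labels) pair of equal-length lists; pad tokens
    # with 0 and labels with -100 so padded positions are ignored by the loss.
    tokens = pad_sequence(
        [torch.tensor(tokens) for tokens, _ in batch],
        batch_first=True,
        padding_value=0,
    )
    labels = pad_sequence(
        [torch.tensor(labels) for _, labels in batch],
        batch_first=True,
        padding_value=CROSS_ENTROPY_IGNORE_IDX,
    )
    return tokens, labels


tokenizer = Tokenizer.from_file("/path/to/tokenizer.model")  # placeholder path
samsum_ds = samsum_dataset(tokenizer=tokenizer)  # prompt masked by default after PATCH 2/3
for tokens, labels in DataLoader(samsum_ds, batch_size=8, collate_fn=pad_collate):
    print(tokens.shape, labels.shape)

Because both prompt masking and padding use -100, the padded labels can be fed directly to torch.nn.functional.cross_entropy, whose default ignore_index is -100.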