Refactor datasets and tokenizer (#624)
ebsmothers authored Apr 2, 2024
1 parent 0a82ea4 commit 8183b42
Showing 23 changed files with 984 additions and 821 deletions.
7 changes: 5 additions & 2 deletions tests/test_utils.py
@@ -18,7 +18,7 @@

import torch
from torch import nn

from torchtune.modules import Tokenizer

skip_if_cuda_not_available = unittest.skipIf(
not torch.cuda.is_available(), "CUDA is not available"
@@ -31,8 +31,11 @@
"llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
}

# Inherit from tokenizer class to reuse its tokenize_messages method
class DummyTokenizer(Tokenizer):
def __init__(self):
self.encodes_whitespace = False

class DummyTokenizer:
def encode(self, text, add_bos=True, add_eos=True, **kwargs):
words = text.split()
tokens = [len(word) for word in words]
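For reference, a minimal sketch of how the updated DummyTokenizer helper behaves. The word-length encoding comes from the hunk above; the bos/eos ids (0 and -1) are an assumption inferred from the expected token values elsewhere in this diff, since the remainder of encode is not shown here.

tokenizer = DummyTokenizer()
# Each whitespace-separated word becomes one id equal to its character length,
# framed by the assumed bos id 0 and eos id -1.
tokens = tokenizer.encode("hello world")
assert tokens == [0, 5, 5, -1]
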
32 changes: 0 additions & 32 deletions tests/torchtune/config/test_config_utils.py
@@ -9,11 +9,9 @@
import pytest
from torchtune.config._utils import (
_get_component_from_path,
_get_template,
_merge_yaml_and_cli_args,
InstantiationError,
)
from torchtune.data import AlpacaInstructTemplate
from torchtune.utils.argparse import TuneRecipeArgumentParser

_CONFIG = {
@@ -109,33 +107,3 @@ def test_merge_yaml_and_cli_args(self, mock_load):
ValueError, match="Command-line overrides must be in the form of key=value"
):
_ = _merge_yaml_and_cli_args(yaml_args, cli_args)

def test_get_template(self):
# Test valid template class
template = _get_template("AlpacaInstructTemplate")
assert isinstance(template, AlpacaInstructTemplate)

# Test invalid template class
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template("InvalidTemplate")

# Test valid template strings
valid_templates = [
"Instruction: {instruction}\nInput: {input}",
"Instruction: {instruction}",
"{a}",
]
for template in valid_templates:
assert _get_template(template) == template

# Test invalid template strings
invalid_templates = ["hello", "{}", "a}{b"]
for template in invalid_templates:
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template(template)
130 changes: 130 additions & 0 deletions tests/torchtune/data/test_chat_formats.py
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from torchtune.data import ChatMLFormat, Llama2ChatFormat, Message, MistralChatFormat

# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
CHAT_SAMPLE = [
Message(
role="system",
content="You are an AI assistant. User will you give you a task. "
"Your goal is to complete the task as faithfully as you can. "
"While performing the task think step-by-step and justify your steps.",
),
Message(
role="user",
content="Please briefly summarize this news article:\n\nAOL.com Video - "
"Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your "
"8-year-old drive your car? How about on an icy road? Well one father in "
"Russia did just that, and recorded the entire thing. To her credit, the "
"child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , "
"caught on camera , child driver , pix11\n\nSummary:",
),
Message(
role="assistant",
content="A father in Russia allowed his 8-year-old child to drive his car "
"on an icy road and recorded the event. The child appeared to be handling the "
"situation well, showcasing their driving skills despite the challenging conditions.",
),
]


def _assert_dialogue_equal(actual, expected):
assert len(actual) == len(expected)
for i in range(len(actual)):
assert actual[i].role == expected[i].role
assert actual[i].content == expected[i].content


class TestLlama2ChatFormat:
expected_dialogue = [
Message(
role="user",
content="[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
"Your goal is to complete the task as faithfully as you can. While performing "
"the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
"briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary: [/INST] ",
),
Message(
role="assistant",
content="A father in Russia allowed his 8-year-old child to drive his car on an "
"icy road and recorded the event. The child appeared to be handling the situation well, "
"showcasing their driving skills despite the challenging conditions.",
),
]

def test_format(self):
actual = Llama2ChatFormat.format(CHAT_SAMPLE)
_assert_dialogue_equal(actual, self.expected_dialogue)


class TestMistralChatFormat:
expected_dialogue = [
Message(
role="user",
content="[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary: [/INST] ",
),
Message(
role="assistant",
content="A father in Russia allowed his 8-year-old child to drive his car on an "
"icy road and recorded the event. The child appeared to be handling the situation well, "
"showcasing their driving skills despite the challenging conditions.",
),
]

def test_format(self):
no_system_sample = CHAT_SAMPLE[1:]
actual = MistralChatFormat.format(no_system_sample)
_assert_dialogue_equal(actual, self.expected_dialogue)

def test_format_with_system_prompt_raises(self):
with pytest.raises(
ValueError, match="System prompts are not supported in MistralChatFormat"
):
_ = MistralChatFormat.format(CHAT_SAMPLE)


class TestChatMLFormat:
expected_dialogue = [
Message(
role="system",
content="<|im_start|>system\nYou are an AI assistant. User will you give you a task. "
"Your goal is to complete the task as faithfully as you can. While performing "
"the task think step-by-step and justify your steps.<|im_end|>\n",
),
Message(
role="user",
content="<|im_start|>user\nPlease "
"briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary:<|im_end|>\n",
),
Message(
role="assistant",
content="<|im_start|>assistant\nA father in Russia allowed his 8-year-old child to drive his car on an "
"icy road and recorded the event. The child appeared to be handling the situation well, "
"showcasing their driving skills despite the challenging conditions.<|im_end|>",
),
]

def test_format(self):
actual = ChatMLFormat.format(CHAT_SAMPLE)
_assert_dialogue_equal(actual, self.expected_dialogue)
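For reference, a minimal usage sketch of the chat format API exercised by these tests. The Message fields and the class-level format call mirror the test code above; the example messages themselves are made up.

from torchtune.data import Llama2ChatFormat, Message

messages = [
    Message(role="system", content="You are a helpful assistant."),
    Message(role="user", content="What is 2 + 2?"),
    Message(role="assistant", content="4"),
]
formatted = Llama2ChatFormat.format(messages)
# As in the expected dialogue above, the system and user turns are folded into a
# single user message wrapped in [INST] <<SYS>> ... <</SYS>> ... [/INST];
# the assistant turn passes through unchanged.
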
83 changes: 10 additions & 73 deletions tests/torchtune/data/test_data_utils.py
@@ -4,84 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from tests.test_utils import DummyTokenizer
from torchtune.data import tokenize_prompt_and_response, truncate
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX


def test_tokenize_prompt_and_response():
tokenizer = DummyTokenizer()
prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
response = "I always know what I'm doing, do you?"
prompt_length = 12
expected_tokenized_prompt = [
0,
12,
4,
2,
2,
12,
6,
4,
2,
2,
6,
9,
1,
6,
4,
4,
3,
6,
2,
4,
-1,
]
expected_tokenized_label = [CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [
1,
6,
4,
4,
3,
6,
2,
4,
-1,
]

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
tokenizer, prompt, response
)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_label

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
tokenizer, prompt, response, train_on_input=True
)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_prompt
from torchtune.data import truncate


def test_truncate():
prompt_tokens = [1, 2, 3, 4, -1]
label_tokens = [1, 2, 3, 4, -1]
tokens = [1, 2, 3, 4, -1]

# Test no truncation
truncated_prompt_tokens, truncated_label_tokens = truncate(
tokenizer=DummyTokenizer(),
prompt_tokens=prompt_tokens,
label_tokens=label_tokens,
truncated_tokens = truncate(
tokens=tokens,
max_seq_len=5,
eos_id=-1,
)
assert truncated_prompt_tokens == prompt_tokens
assert truncated_label_tokens == label_tokens
assert truncated_tokens == tokens

# Test truncated
truncated_prompt_tokens, truncated_label_tokens = truncate(
tokenizer=DummyTokenizer(),
prompt_tokens=prompt_tokens,
label_tokens=label_tokens,
max_seq_len=4,
)
assert truncated_prompt_tokens == [1, 2, 3, -1]
assert truncated_label_tokens == [1, 2, 3, -1]
masks = [True, True, False, True, False]
# Test truncated mask
truncated_masks = truncate(tokens=masks, max_seq_len=4, eos_id=False)
assert truncated_masks == [True, True, False, False]
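For reference, a minimal sketch of the new truncate signature tested above. Its behavior on token lists is inferred from the mask example and the removed prompt/label test, so treat the expected output as an assumption rather than a guarantee.

from torchtune.data import truncate

tokens = [1, 2, 3, 4, -1]
# Truncating to max_seq_len=4 keeps the first three ids and re-appends the eos id.
assert truncate(tokens=tokens, max_seq_len=4, eos_id=-1) == [1, 2, 3, -1]
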