
Refactor datasets and tokenizer #624

Merged 11 commits on Apr 2, 2024
tests/test_utils.py (42 additions, 1 deletion)
@@ -12,12 +12,14 @@
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
-from typing import Any, Dict, Generator, Optional, TextIO, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, TextIO, Tuple, Union

import pytest

import torch
from torch import nn
+from torchtune.data._types import Message
+from torchtune.data._utils import truncate


skip_if_cuda_not_available = unittest.skipIf(
@@ -50,6 +52,45 @@ def eos_id(self):
    def bos_id(self):
        return 0

    def tokenize_messages(
Contributor: We should think of an alternative solution to this, because we'll have to update it every time the real tokenize_messages is updated. Can we do a more stripped-down approach for testing purposes, or is this already the most barebones version?

Contributor Author: Yeah, I agree. The main thing I wanted here was to guarantee I could replicate the performance in the existing unit test using the same logic. We can definitely use a simpler method, but we would have to change the expected values. (Really, I should just add a test for tokenize_messages on the tokenizer; then we can use something simple here and still be confident it's working as expected.)

Contributor: You added the tokenize_messages test; should we simplify here?

Contributor Author: Yeah, though I do like explicitly testing fetching a sample from ChatDataset. I have another idea to simplify the code here 😃

        self, messages: List[Message], max_seq_len: Optional[int] = None
    ):
        start_of_turn = True
        end_of_turn = False
        tokenized_messages = []
        mask = []
        for message in messages:
            # If assistant message, this is the end of a turn
            end_of_turn = message.role == "assistant"

            # Prepend BOS on start of new turns
            if start_of_turn:
                tokenized_messages.append(self.bos_id)
                mask.append(message.masked)
                start_of_turn = False

            # Tokenize current message, append with masks
            tokens = self.encode(
                message.content,
                add_bos=False,
                add_eos=False,
            )
            tokenized_messages.extend(tokens)
            mask.extend([message.masked] * len(tokens))

            # If assistant message, append EOS at end
            if end_of_turn:
                tokenized_messages.append(self.eos_id)
                mask.append(message.masked)
                end_of_turn = False
                start_of_turn = True

        if max_seq_len:
            tokenized_messages = truncate(tokenized_messages, max_seq_len, self.eos_id)
            mask = truncate(mask, max_seq_len, False)

        return tokenized_messages, mask
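
For reference, a quick sketch of what this stub produces. It assumes, as the expected values elsewhere in these tests imply, that DummyTokenizer.encode maps each whitespace-delimited word to its character length, with bos_id 0 and eos_id -1, and that Message carries the masked flag used above. The example is illustrative and not part of the diff.

# Hypothetical usage of the test stub (not part of the PR):
tokenizer = DummyTokenizer()
messages = [
    Message(role="user", content="hello world", masked=True),
    Message(role="assistant", content="hi there", masked=False),
]
tokens, mask = tokenizer.tokenize_messages(messages)
# tokens == [0, 5, 5, 2, 5, -1]   # BOS, "hello", "world", "hi", "there", EOS
# mask == [True, True, True, False, False, False]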


def get_assets_path():
    return Path(__file__).parent / "assets"
tests/torchtune/data/test_chat_formats.py (130 additions, new file)
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from torchtune.data import ChatMLFormat, Llama2ChatFormat, Message, MistralChatFormat

# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
CHAT_SAMPLE = [
    Message(
        role="system",
        content="You are an AI assistant. User will you give you a task. "
        "Your goal is to complete the task as faithfully as you can. "
        "While performing the task think step-by-step and justify your steps.",
    ),
    Message(
        role="user",
        content="Please briefly summarize this news article:\n\nAOL.com Video - "
        "Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your "
        "8-year-old drive your car? How about on an icy road? Well one father in "
        "Russia did just that, and recorded the entire thing. To her credit, the "
        "child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , "
        "caught on camera , child driver , pix11\n\nSummary:",
    ),
    Message(
        role="assistant",
        content="A father in Russia allowed his 8-year-old child to drive his car "
        "on an icy road and recorded the event. The child appeared to be handling the "
        "situation well, showcasing their driving skills despite the challenging conditions.",
    ),
]


def _assert_dialogue_equal(actual, expected):
    assert len(actual) == len(expected)
    for i in range(len(actual)):
        assert actual[i].role == expected[i].role
        assert actual[i].content == expected[i].content


class TestLlama2ChatFormat:
    expected_dialogue = [
        Message(
            role="user",
            content="[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
            "Your goal is to complete the task as faithfully as you can. While performing "
            "the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
            "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary: [/INST] ",
        ),
        Message(
            role="assistant",
            content="A father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.",
        ),
    ]

    def test_format(self):
        actual = Llama2ChatFormat.format(CHAT_SAMPLE)
        _assert_dialogue_equal(actual, self.expected_dialogue)
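
These expected values encode the Llama 2 chat template: the system prompt is folded into the first user turn inside <<SYS>> tags, and the user content is wrapped in [INST] ... [/INST]. A minimal sketch of that templating, written here as an illustration rather than torchtune's actual implementation:

# Sketch of the Llama 2 templating implied by the expected values above
# (hypothetical helper, not the library's code).
def llama2_format_sketch(messages):
    formatted = []
    system = None
    for msg in messages:
        if msg.role == "system":
            system = msg.content  # hold until the next user turn
        elif msg.role == "user":
            if system is not None:
                content = f"[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{msg.content} [/INST] "
                system = None
            else:
                content = f"[INST] {msg.content} [/INST] "
            formatted.append(Message(role="user", content=content))
        else:
            formatted.append(msg)  # assistant turns pass through unchanged
    return formatted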


class TestMistralChatFormat:
    expected_dialogue = [
        Message(
            role="user",
            content="[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary: [/INST] ",
        ),
        Message(
            role="assistant",
            content="A father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.",
        ),
    ]

    def test_format(self):
        no_system_sample = CHAT_SAMPLE[1:]
        actual = MistralChatFormat.format(no_system_sample)
        _assert_dialogue_equal(actual, self.expected_dialogue)

    def test_format_with_system_prompt_raises(self):
        with pytest.raises(
            ValueError, match="System prompts are not supported in MistralChatFormat"
        ):
            _ = MistralChatFormat.format(CHAT_SAMPLE)


class TestChatMLFormat:
    expected_dialogue = [
        Message(
            role="system",
            content="<|im_start|>system\nYou are an AI assistant. User will you give you a task. "
            "Your goal is to complete the task as faithfully as you can. While performing "
            "the task think step-by-step and justify your steps.<|im_end|>\n",
        ),
        Message(
            role="user",
            content="<|im_start|>user\nPlease "
            "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary:<|im_end|>\n",
        ),
        Message(
            role="assistant",
            content="<|im_start|>assistant\nA father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.<|im_end|>",
        ),
    ]

    def test_format(self):
        actual = ChatMLFormat.format(CHAT_SAMPLE)
        _assert_dialogue_equal(actual, self.expected_dialogue)
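
The ChatML expectations follow a simpler rule: each message is wrapped in <|im_start|>{role}\n ... <|im_end|>, with a trailing newline everywhere except after the final assistant turn. A minimal sketch of that rule, inferred from the expected values rather than taken from the library:

# Sketch of the ChatML wrapping implied by the expected values above
# (hypothetical helper, not the library's code).
def chatml_format_sketch(messages):
    formatted = []
    for i, msg in enumerate(messages):
        last = i == len(messages) - 1
        end = "<|im_end|>" if (msg.role == "assistant" and last) else "<|im_end|>\n"
        formatted.append(
            Message(role=msg.role, content=f"<|im_start|>{msg.role}\n{msg.content}{end}")
        )
    return formatted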
tests/torchtune/data/test_data_utils.py (10 additions, 73 deletions)
@@ -4,84 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from tests.test_utils import DummyTokenizer
-from torchtune.data import tokenize_prompt_and_response, truncate
-from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-
-
-def test_tokenize_prompt_and_response():
-    tokenizer = DummyTokenizer()
-    prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
-    response = "I always know what I'm doing, do you?"
-    prompt_length = 12
-    expected_tokenized_prompt = [
-        0, 12, 4, 2, 2, 12, 6, 4, 2, 2, 6, 9, 1, 6, 4, 4, 3, 6, 2, 4, -1,
-    ]
-    expected_tokenized_label = [CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [
-        1, 6, 4, 4, 3, 6, 2, 4, -1,
-    ]
-
-    tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
-        tokenizer, prompt, response
-    )
-    assert tokenized_prompt == expected_tokenized_prompt
-    assert tokenized_label == expected_tokenized_label
-
-    tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
-        tokenizer, prompt, response, train_on_input=True
-    )
-    assert tokenized_prompt == expected_tokenized_prompt
-    assert tokenized_label == expected_tokenized_prompt
+from torchtune.data import truncate


def test_truncate():
-    prompt_tokens = [1, 2, 3, 4, -1]
-    label_tokens = [1, 2, 3, 4, -1]
+    tokens = [1, 2, 3, 4, -1]

    # Test no truncation
-    truncated_prompt_tokens, truncated_label_tokens = truncate(
-        tokenizer=DummyTokenizer(),
-        prompt_tokens=prompt_tokens,
-        label_tokens=label_tokens,
+    truncated_tokens = truncate(
+        tokens=tokens,
        max_seq_len=5,
+        eos_id=-1,
    )
-    assert truncated_prompt_tokens == prompt_tokens
-    assert truncated_label_tokens == label_tokens
+    assert truncated_tokens == tokens

-    # Test truncated
-    truncated_prompt_tokens, truncated_label_tokens = truncate(
-        tokenizer=DummyTokenizer(),
-        prompt_tokens=prompt_tokens,
-        label_tokens=label_tokens,
-        max_seq_len=4,
-    )
-    assert truncated_prompt_tokens == [1, 2, 3, -1]
-    assert truncated_label_tokens == [1, 2, 3, -1]
+    masks = [True, True, False, True, False]
+    # Test truncated mask
+    truncated_masks = truncate(tokens=masks, max_seq_len=4, eos_id=False)
+    assert truncated_masks == [True, True, False, False]
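
The contract these assertions pin down is small: keep the first max_seq_len elements and, if anything was cut, overwrite the last kept element with eos_id so the sequence still terminates. The mask case reuses the same function with eos_id=False. A sketch of that behavior, as a simplified illustration rather than necessarily torchtune's exact implementation:

# Sketch of the truncate contract exercised by the test above
# (hypothetical helper, simplified).
def truncate_sketch(tokens, max_seq_len, eos_id):
    if len(tokens) <= max_seq_len:
        return tokens  # nothing to cut
    truncated = tokens[:max_seq_len]
    truncated[-1] = eos_id  # preserve a terminator after truncation
    return truncated

assert truncate_sketch([1, 2, 3, 4, -1], max_seq_len=4, eos_id=-1) == [1, 2, 3, -1]
assert truncate_sketch(
    [True, True, False, True, False], max_seq_len=4, eos_id=False
) == [True, True, False, False]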