Refactor datasets and tokenizer #624
Merged
Changes from 7 commits (11 commits total):
- b9e0f5f [RFC] Refactor datasets and tokenizer (ebsmothers)
- 117dc32 add option to remove leading whitespace (ebsmothers)
- 73ef79d address comments, add tokenizer test, better whitespace handling (ebsmothers)
- a6e24d7 better choice of prefix (ebsmothers)
- 6592526 get tests passing (ebsmothers)
- d71f397 add missing assertion (ebsmothers)
- 5a66885 address comments (ebsmothers)
- dbc646e address comments (ebsmothers)
- c29e8cb missed dead import (ebsmothers)
- f9ec55d bug fix and some cleanup (ebsmothers)
- 5603599 address remaining comments (ebsmothers)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from torchtune.data import ChatMLFormat, Llama2ChatFormat, Message, MistralChatFormat

# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
CHAT_SAMPLE = [
    Message(
        role="system",
        content="You are an AI assistant. User will you give you a task. "
        "Your goal is to complete the task as faithfully as you can. "
        "While performing the task think step-by-step and justify your steps.",
    ),
    Message(
        role="user",
        content="Please briefly summarize this news article:\n\nAOL.com Video - "
        "Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your "
        "8-year-old drive your car? How about on an icy road? Well one father in "
        "Russia did just that, and recorded the entire thing. To her credit, the "
        "child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , "
        "caught on camera , child driver , pix11\n\nSummary:",
    ),
    Message(
        role="assistant",
        content="A father in Russia allowed his 8-year-old child to drive his car "
        "on an icy road and recorded the event. The child appeared to be handling the "
        "situation well, showcasing their driving skills despite the challenging conditions.",
    ),
]


def _assert_dialogue_equal(actual, expected):
    assert len(actual) == len(expected)
    for i in range(len(actual)):
        assert actual[i].role == expected[i].role
        assert actual[i].content == expected[i].content


class TestLlama2ChatFormat:
    expected_dialogue = [
        Message(
            role="user",
            content="[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
            "Your goal is to complete the task as faithfully as you can. While performing "
            "the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
            "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary: [/INST] ",
        ),
        Message(
            role="assistant",
            content="A father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.",
        ),
    ]

    def test_format(self):
        actual = Llama2ChatFormat.format(CHAT_SAMPLE)
        _assert_dialogue_equal(actual, self.expected_dialogue)


class TestMistralChatFormat:
    expected_dialogue = [
        Message(
            role="user",
            content="[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary: [/INST] ",
        ),
        Message(
            role="assistant",
            content="A father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.",
        ),
    ]

    def test_format(self):
        no_system_sample = CHAT_SAMPLE[1:]
        actual = MistralChatFormat.format(no_system_sample)
        _assert_dialogue_equal(actual, self.expected_dialogue)

    def test_format_with_system_prompt_raises(self):
        with pytest.raises(
            ValueError, match="System prompts are not supported in MistralChatFormat"
        ):
            _ = MistralChatFormat.format(CHAT_SAMPLE)


class TestChatMLFormat:
    expected_dialogue = [
        Message(
            role="system",
            content="<|im_start|>system\nYou are an AI assistant. User will you give you a task. "
            "Your goal is to complete the task as faithfully as you can. While performing "
            "the task think step-by-step and justify your steps.<|im_end|>\n",
        ),
        Message(
            role="user",
            content="<|im_start|>user\nPlease "
            "briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
            "Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
            "How about on an icy road? Well one father in Russia did just that, and recorded "
            "the entire thing. To her credit, the child seemed to be doing a great job. "
            "(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
            "Summary:<|im_end|>\n",
        ),
        Message(
            role="assistant",
            content="<|im_start|>assistant\nA father in Russia allowed his 8-year-old child to drive his car on an "
            "icy road and recorded the event. The child appeared to be handling the situation well, "
            "showcasing their driving skills despite the challenging conditions.<|im_end|>",
        ),
    ]

    def test_format(self):
        actual = ChatMLFormat.format(CHAT_SAMPLE)
        _assert_dialogue_equal(actual, self.expected_dialogue)
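The ChatML expectations above can be illustrated with a minimal, self-contained sketch. This is not torchtune's implementation: the Message dataclass and chatml_format function below are simplified stand-ins that mirror only the template visible in TestChatMLFormat (each turn wrapped in <|im_start|>role ... <|im_end|>, with no trailing newline after the assistant turn so generation can stop cleanly).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Message:
    # Simplified stand-in for torchtune.data.Message
    role: str
    content: str


def chatml_format(messages: List[Message]) -> List[Message]:
    """Wrap each message in ChatML delimiters, matching the expected
    values in TestChatMLFormat: non-assistant turns end with a newline
    after <|im_end|>, the assistant turn does not."""
    formatted = []
    for m in messages:
        end = "<|im_end|>" if m.role == "assistant" else "<|im_end|>\n"
        formatted.append(
            Message(role=m.role, content=f"<|im_start|>{m.role}\n{m.content}{end}")
        )
    return formatted


sample = [
    Message(role="user", content="Hello"),
    Message(role="assistant", content="Hi there"),
]
out = chatml_format(sample)
# out[0].content == "<|im_start|>user\nHello<|im_end|>\n"
# out[1].content == "<|im_start|>assistant\nHi there<|im_end|>"
```

This sketch is only a reading aid for the test fixtures; the real format classes in torchtune.data may handle more cases (e.g. system prompts, masking) than shown here.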
We should think of an alternative solution to this, because we'll have to update this every time the real tokenize_messages is updated. Can we do a more stripped-down approach for testing purposes, or is this the most barebones?
Yeah, I agree. The main thing I wanted here was to guarantee I could replicate the behavior in the existing unit test using the same logic. We can definitely use a simpler method, but we will have to change the expected values. (Really I should just add a test for tokenize_messages on the tokenizer; then we can use something simple here and still be confident it's working as expected.)
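The stripped-down approach discussed here might look something like the following sketch. The DummyTokenizer fixture, its token ids, and its tokenize_messages signature are hypothetical illustrations, not torchtune's actual test fixtures: encoding each word as its length makes expected token ids trivial to compute by hand.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Message:
    role: str
    content: str


class DummyTokenizer:
    """Hypothetical test fixture: encodes each whitespace-separated word
    as its character length, so expected outputs are easy to hand-compute."""

    bos_id = -1
    eos_id = -2

    def encode(self, text: str) -> List[int]:
        return [len(word) for word in text.split()]

    def tokenize_messages(
        self, messages: List[Message]
    ) -> Tuple[List[int], List[bool]]:
        tokens = [self.bos_id]
        # Mask (True) everything except assistant content, which is the
        # training target in a chat dataset.
        mask = [True]
        for m in messages:
            ids = self.encode(m.content)
            tokens.extend(ids)
            mask.extend([m.role != "assistant"] * len(ids))
        tokens.append(self.eos_id)
        mask.append(True)
        return tokens, mask


msgs = [Message("user", "hello world"), Message("assistant", "hi")]
tokens, mask = DummyTokenizer().tokenize_messages(msgs)
print(tokens)  # [-1, 5, 5, 2, -2]
print(mask)    # [True, True, True, False, True]
```

With a fixture like this, the dataset test only has to assert on small hand-computed lists, while the real tokenize_messages logic is covered by its own dedicated test.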
You added the tokenize_messages test. Should we simplify here?
Yeah, I do like explicitly testing fetching a sample from ChatDataset though. I have another idea to simplify the code here 😃