Refactor datasets and tokenizer #624
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/624
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 5603599 with merge base 73647e2.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is overall quite intuitive and lets Instruct and ChatDataset still play specialized roles while utilizing common APIs. I think we've struck a nice balance here.
tests/test_utils.py
Outdated
@@ -50,6 +52,45 @@ def eos_id(self):
    def bos_id(self):
        return 0

    def tokenize_messages(
We should think of an alternative solution to this, because we'll have to update this every time the real tokenize messages is updated. Can we do a more stripped down approach for testing purposes or is this the most barebones?
Yeah I agree. The main thing I wanted here was to guarantee I could replicate the performance in the existing unit test using the same logic. We can definitely use a simpler method but will have to change the expected values. (Really I should just add a test for tokenize_messages on the tokenizer, then we can use something simple here and still be confident it's working as expected.)
you added the tokenize_messages test - should we simplify here?
Yeah, I do like explicitly testing fetching a sample from ChatDataset though. I have another idea to simplify the code here 😃
formatted_dialogue = []
for message in messages:
    content = ""
    if message.role == "system":
nit: could put system, user, assistant in a dictionary and just key on message.role
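A minimal sketch of that suggestion (the tag strings below are placeholders, not the exact Llama2 special tokens):

from dataclasses import dataclass
from typing import List

@dataclass
class Message:
    # simplified stand-in for torchtune.data.Message
    role: str
    content: str

# Placeholder role tags; the real ones would be the Llama2 [INST]/<<SYS>> strings
ROLE_TAGS = {
    "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
    "user": ("[INST] ", " [/INST] "),
    "assistant": ("", " "),
}

def format_dialogue(messages: List[Message]) -> List[Message]:
    # Key on message.role instead of branching per role
    formatted_dialogue = []
    for message in messages:
        start_tag, end_tag = ROLE_TAGS[message.role]
        formatted_dialogue.append(
            Message(role=message.role, content=start_tag + message.content + end_tag)
        )
    return formatted_dialogue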
from torchtune.datasets import slimorca_dataset
from torchtune.modules.tokenizer import Tokenizer

-LLAMA_TEMPLATE = Llama2ChatTemplate()
+LLAMA_TEMPLATE = Llama2ChatFormat()
This should not be initialized here, just the class name
torchtune/data/_chat_formats.py
Outdated
def format(
    cls,
    sample: List[Message],
) -> str:
Need to update return to List[Message]
torchtune/data/_templates.py
Outdated
This file should be deleted?
torchtune/datasets/_chat.py
Outdated
)
labels = list(np.where(np.logical_not(mask), tokens, CROSS_ENTROPY_IGNORE_IDX))
Very nice
Nice indeed, but please add comment for mere mortals like me who need to take more than a min to understand this :)
I made this needlessly complicated so that's my fault
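For the mere mortals, a small worked example of what that line computes (the CROSS_ENTROPY_IGNORE_IDX value is assumed to be the usual -100):

import numpy as np

CROSS_ENTROPY_IGNORE_IDX = -100  # label value ignored by the cross-entropy loss

# mask[i] is True where the token should NOT contribute to the loss
# (e.g. prompt tokens when train_on_input=False).
tokens = [1, 5, 7, 2]
mask = [True, True, False, False]

# np.where(cond, a, b) picks from `a` where cond is True, else from `b`:
# unmasked positions keep the token id as the label, masked positions get the ignore index.
labels = list(np.where(np.logical_not(mask), tokens, CROSS_ENTROPY_IGNORE_IDX))
# labels == [-100, -100, 7, 2]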
-    convert_to_dialogue=convert_to_dialogue,
-    template=_get_template(template),
+    convert_to_messages=convert_to_messages,
+    chat_format=chat_format,
This needs a mapping from str to the actual class pointer, a simpler version of get_template
I am just gonna pass the ChatFormat directly for now, lmk if that makes sense
hm, then this builder will require a nested component to be usable from the config, so it won't work from the config as-is. I can update it in a follow-up if needed
will leave here and duck - is the one layer of component instantiation coming in the way of getting stuff done?
Yeah let's do in a follow-up. This will only affect custom chat datasets, right? Which we don't have any of yet anyways. Re nested instantiation, imo this is not a sufficient reason to add it.. I think we can find another way here
I think the string mapping is very simple and uses our existing tools without requiring nested instantiation or impacting UX. Will do a follow-up and we can discuss there.
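For reference, a hedged sketch of what that string mapping could look like (the helper name and import path are illustrative, not the final API):

# Hypothetical str -> class mapping, analogous to a stripped-down _get_template.
# The import path is assumed; the ChatFormat classes are the ones added in this PR.
from torchtune.data import ChatMLFormat, Llama2ChatFormat, MistralChatFormat

ALL_CHAT_FORMATS = {
    "Llama2ChatFormat": Llama2ChatFormat,
    "MistralChatFormat": MistralChatFormat,
    "ChatMLFormat": ChatMLFormat,
}

def _get_chat_format(chat_format: str):
    if chat_format not in ALL_CHAT_FORMATS:
        raise ValueError(f"Unknown chat format: {chat_format}")
    return ALL_CHAT_FORMATS[chat_format]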
torchtune/datasets/_slimorca.py
Outdated
-    convert_to_dialogue=sharegpt_to_llama2_dialogue,
-    template=Llama2ChatTemplate(),
+    convert_to_messages=sharegpt_to_llama2_messages,
+    chat_format=Llama2ChatFormat(),
No () here because format is a class method
torchtune/modules/tokenizer.py
Outdated
mask.extend([message.masked] * len(tokens))

# Break out early if we reach max_seq_len
if max_seq_len and len(tokens) >= max_seq_len:
len(tokenized_messages)?
-    self._tokenizer, prompt_tokens, label_tokens, self.max_seq_len
+    messages = self._convert_to_messages(sample)
+    messages = self.chat_format.format(messages)
+    tokens, mask = self._tokenizer.tokenize_messages(
Where is train_on_input used for ChatDataset?
Good point. Gonna pass it to _convert_to_messages, I think it makes most sense there tbh. Basically we want it wherever we are building the messages
mm, _convert_to_messages is provided by the user right? so then we have to enforce that their Callables take in this parameter and the burden is on them to mask appropriately. maybe pass it to chat_format or tokenize_messages and let the user know that if they set it, it will overwrite whatever custom masking they set up in _convert_to_messages?
Yeah I agree. Similar to my comment above, can we handle this in a follow-up? I think it's a relatively standalone change not too related to all this tokenizer business
@pytest.mark.parametrize(
    "config", ["full_single_device_low_memory", "full_single_device"]
)
# @pytest.mark.parametrize(
just for some testing, will remove
) -> List[int]:
    """Encode text into token IDs.

    Args:
        text (str): The input text to be encoded, unbatched.
        add_bos (bool): Whether to prepend BOS to the input, defaults to True.
        add_eos (bool): Whether to append EOS to the input, defaults to True.
        trim_leading_whitespace (bool): Whether to trim leading whitespace from
wanna link the discussion you found here so people know why this is a thing?
On second thought, I think it might be a bit misleading to just link that with no context. I am gonna add my own comment that imo is more relevant.
Llama2ChatFormat and MistralChatFormat have a lot of whitespaces between messages - have you considered how that interacts with the trimming? should we remove those whitespaces pre-emptively in the ChatFormat classes?
CHAT_SAMPLE = [
    Message(
        role="system",
        content="You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.",  # noqa: B950
nit: maybe can just chunk the string like the expected_dialogues below so we can remove the noqa?
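For example, relying on implicit concatenation of adjacent string literals keeps the line length in check without the noqa (same sample text, just split across lines):

content = (
    "You are an AI assistant. User will you give you a task. "
    "Your goal is to complete the task as faithfully as you can. "
    "While performing the task think step-by-step and justify your steps."
)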
@@ -9,7 +9,9 @@
from torchtune.data._types import Message

-def sharegpt_to_llama2_messages(sample: Mapping[str, Any]) -> List[Message]:
+def sharegpt_to_llama2_messages(
+    sample: Mapping[str, Any], train_on_input: bool = False
don't know if we should make these transforms require train_on_input, then we need to enforce a certain API if a user passes in their own transform
Yeah I agree. I'd like to tackle this as a follow-up tbh
torchtune/modules/tokenizer.py
Outdated
self.whitespace_encodings = {
    c: self.spm_model.encode(c) for c in WHITESPACE_CHARS
}
self.encodes_whitespace = any(self.whitespace_encodings.values())
what happens if spm_model fails to encode a character, what is returned?
Empty list. That's why I'm doing this -- to check if whitespaces are encoded as part of the model. In fact prob don't even need to save the encodings dict at all
torchtune/modules/tokenizer.py
Outdated
Returns:
    List[int]: The encoded token IDs.
"""
if trim_leading_whitespace:
    # Can define our own custom prefix depending on vocab if needed
    if not hasattr(self, "prefix"):
        self.prefix = prefix or "pre"
is "\n" not a reasonable default?
@@ -63,6 +69,7 @@ def encode(
    add_bos: bool = True,
    add_eos: bool = True,
    trim_leading_whitespace: bool = False,
    prefix: Optional[str] = None,
do you need to assert that trim_leading_whitespace is True if prefix is set?
Idk about assert, but maybe warn? Cause we can set prefix and have it just be a no-op. In general I think it is quite rare to explicitly set prefix though, it should not be needed for the canonical Llama2 tokenizer vocab
OK yeah I am gonna leave out the assertion cause tbh prefix is not really something that I expect people to be experimenting with. Lmk if that makes sense
# Can define our own custom prefix depending on vocab if needed
if not hasattr(self, "prefix"):
    self.prefix = prefix or "pre"
    self.encoded_prefix = self.spm_model.encode(
shouldn't this be outside the if statement if prefix is already set?
Similar to above, we do this one time and cache the result.
Returns:
    List[int]: The encoded token IDs.
"""
if trim_leading_whitespace:
    # Can define our own custom prefix depending on vocab if needed
    if not hasattr(self, "prefix"):
where is self.prefix set?
This is basically just caching it the first time we use it, hence why it's set in the following line. If it's already been encoded, no need to keep re-encoding it since it will remain fixed for the life of the program.
trim_leading_whitespace = (
    (not start_of_turn)
    and self.encodes_whitespace
    and not prev_ends_with_space
where is prev_ends_with_space defined?
L153 as of this version. This is so that we can work with e.g. the grammar format
this should throw an error because I don't see it defined on the very first message, it is only defined after. you should initialize it as False (or whichever is the default behavior) above the for loop
Oh god I think it's breaking out of the and early in the test cases so we never hit this. Good catch
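A toy version of the suggested fix, with the start-of-turn handling simplified to just the first message (the real loop updates it per turn):

def trim_flags(contents, encodes_whitespace=True):
    # Initialize prev_ends_with_space before the loop so the first iteration
    # never reads an undefined name, even when the other conditions are True.
    prev_ends_with_space = False
    start_of_turn = True
    flags = []
    for content in contents:
        flags.append(
            (not start_of_turn) and encodes_whitespace and not prev_ends_with_space
        )
        prev_ends_with_space = content.endswith(" ")
        start_of_turn = False
    return flags

print(trim_flags(["hi ", "there", "friend"]))  # [False, False, True]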
torchtune/data/_templates.py
Outdated
delete
@classmethod
@abstractmethod
TIL!
raise ValueError(
    "System prompts are not supported in MistralChatFormat"
)
Sorry for the very noob question, but why is this not a pass-through as we do for assistant? Basically, are we saying that this format doesn't support "system" as a role, or that it doesn't support the system tags? Or are they the same thing? And if "system" isn't supported, then should we be setting system above to None instead of an empty string?
@RdoubleA may have the best answer here tbh
Mistral does not support the system role, so if a user passes in a message with a system role, we need to either error out or raise a warning that it will be ignored. See for context: vllm-project/vllm#2080 (comment)
Agreed on setting system to None instead of an empty string.
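A hedged sketch of the two behaviors discussed (raise vs. warn-and-drop); role/content tuples are used here only to keep the snippet self-contained:

import warnings

def drop_or_reject_system(messages, strict=True):
    # Mistral's chat format has no system role: either error out (strict) or
    # warn that the system message will be ignored and drop it.
    kept = []
    for role, content in messages:
        if role == "system":
            if strict:
                raise ValueError("System prompts are not supported in MistralChatFormat")
            warnings.warn("System message is ignored: Mistral has no system role")
            continue
        kept.append((role, content))
    return kept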
class ChatMLFormat(ChatFormat):
    """
    OpenAI's Chat Markup Language used by their chat models:
As far as I know, OpenAI models support only tiktoken. Are we adding support for that in this PR or in a follow up? Or not yet?
Let's leave it for a follow-up if you're good with that
I believe they use this format with TikToken: https://community.openai.com/t/how-does-chatml-do-the-exact-formatting/80751
The main motivation to add this template is that it is the default in HF if the model has no custom template: https://github.com/huggingface/transformers/blob/096f304695f7e7b169b031f7814352e900ad71c4/src/transformers/tokenization_utils_base.py#L1838
pass


class AlpacaInstructTemplate(InstructTemplate):
Sorry if I misunderstood earlier, but I thought we're converting this to follow a similar structure to the chat formats i.e. tokenize instruction, input and response separately?
No this is kinda our middle ground for now. This keeps the convenience of string formatting as an option for instruct datasets then casts to messages on the dataset side. Tbh this is one of the kinks I still want to iron out here, we could go all the way on the Message format but I'm hedging a bit for now cause I know people like the string-formatting of prompts. But yeah this may change in the future
torchtune/data/_transforms.py
Outdated
-    dialogue.append(Message(role=role, content=content))
-
-    return dialogue
+    masked = (role != "assistant") and train_on_input
Hmm, maybe I'm misunderstanding this, but shouldn't this be
masked = (role != "assistant") and not train_on_input
i.e. we don't mask if train_on_input is TRUE?
Thanks, good catch. I got thrown off by the test_slimorca_dataset test, which was always asserting that the final token of labels was equal to eos_id, even when max_seq_len was so short that we were not including the assistant message and train_on_input was False. So I think the test was doing the wrong thing in that case (if train_on_input=False and we only have inputs, then everything should be masked, even the EOS token). Lmk if this makes sense to you
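For clarity, the corrected rule in isolation:

def is_masked(role: str, train_on_input: bool) -> bool:
    # Assistant tokens always contribute to the loss; everything else is masked
    # out unless the user opts into training on the input.
    return (role != "assistant") and not train_on_input

assert is_masked("user", train_on_input=False)
assert not is_masked("user", train_on_input=True)
assert not is_masked("assistant", train_on_input=False)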
torchtune/datasets/_chat.py
Outdated
) | ||
labels = list(np.where(np.logical_not(mask), tokens, CROSS_ENTROPY_IGNORE_IDX)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice indeed, but please add comment for mere mortals like me who need to take more than a min to understand this :)
self.encodes_whitespace = any(
    [self.spm_model.encode(c) for c in WHITESPACE_CHARS]
)
no idea what's going on here...
I should probably add a comment
But yeah, discussed the purpose of this variable in the second point here
underlying sentencepiece tokenization. Sentencepiece normally prepends
whitespace to any tokenized text, which can cause differences where
encode(s1) + encode(s2) != encode(s1 + s2) due to leading whitespace
added to s2. Default: False
So when would I set this to True?
I tried to describe in the comment on L138-L141. First, we only run into mismatches due to leading whitespace on the nth split of a string, where n>1. This is because sentencepiece automatically prepends whitespace to a string prior to encoding it, so when we do encode(s1+s2) and encode(s1) + encode(s2), s1 has been prepended with whitespace in both cases, and so the behavior is the same (this is point (a)).
Second, we also only see a mismatch if the tokenizer actually explicitly encodes whitespace, which not all tokenizers do. E.g. if spm_test corresponds to our tokenizer in tests/assets/m.model, and spm_llama corresponds to the usual Llama2 tokenizer, we can actually see that
>>> spm_test.encode(" ", add_bos=False, add_eos=False)
[]
>>> spm_test.encode("\n", add_bos=False, add_eos=False)
[]
>>> spm_llama.encode(" ", add_bos=False, add_eos=False)
[259]
>>> spm_llama.encode("\n", add_bos=False, add_eos=False)
[29871, 13],
so our test tokenizer does not return any tokens when it only sees whitespace. Because of this, we get different behavior in both cases when splitting a single string.
# When the tokenizer doesn't tokenize whitespace, the results match
>>> spm_test.encode("hi\nthere", add_bos=False, add_eos=False)
[476, 70]
>>> spm_test.encode("hi\n", add_bos=False, add_eos=False) + spm_test.encode("there", add_bos=False, add_eos=False)
[476, 70]
# On the regular Llama2 tokenizer, the results do not match
>>> spm_llama.encode("hi\nthere", add_bos=False, add_eos=False)
[7251, 13, 12711]
>>> spm_llama.encode("hi\n", add_bos=False, add_eos=False) + spm_llama.encode("there", add_bos=False, add_eos=False)
[7251, 13, 727]
Finally, there are some prompts that end with " " (e.g. our GrammarErrorCorrectionTemplate). In this case, the concatenated string actually contains a space, so we do not want to strip it from s2. But we do want to remove it from the preceding string to ensure we don't double-count it. This is the reason for the .rstrip(" ") in L150.
So these are the three conditions we need to check to determine whether to do this special handling, and they are all checked in tokenize_messages below.
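A rough sketch of the dummy-prefix trick that trim_leading_whitespace enables, assuming the implementation encodes a cached prefix together with the text and slices the prefix tokens back off (not the verbatim tokenizer code):

def encode_without_leading_whitespace(spm_model, text, prefix="pre"):
    # sentencepiece prepends whitespace to whatever it encodes, so encoding the
    # text with a throwaway prefix and dropping the prefix tokens yields the
    # "mid-string" encoding of `text` without the spurious leading space.
    encoded_prefix = spm_model.encode(prefix)   # cached once in the real tokenizer
    tokens = spm_model.encode(prefix + text)    # prefix absorbs the auto-added space
    return tokens[len(encoded_prefix):]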
mask = []
for message in messages:
    # If assistant message, this is the end of a turn
    end_of_turn = message.role == "assistant"
Maybe I'm grossly misunderstanding, but how are start_of_turn and end_of_turn both True here?
They shouldn't be. End of turn is just on an assistant message. We set that at the beginning of the iteration, then at the end of the iteration, we set end of turn to False and start of turn to True, since the next iteration will be the start of a turn. Lmk if I'm misunderstanding your question here though
start of turn and end of turn can both be true here if the list of messages the user passes in consists of only a single assistant message. This would not be a valid dialogue, but the BOS and EOS should still be appended correctly. I wonder if we should validate that the dialogue is well-formed here or outside this method
Yeah good point. Imo it's not up to the tokenizer to decide this. Let's consider adding a separate utility e.g. validate_messages that can be called from the dataset class or elsewhere as a follow-up.
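A possible shape for that follow-up utility (purely illustrative; validate_messages does not exist yet):

def validate_messages(messages):
    # Hypothetical helper: fail fast on malformed dialogues before tokenization,
    # e.g. a lone assistant message or a system message that is not first.
    if not messages:
        raise ValueError("Messages list is empty")
    last_role = None
    for i, message in enumerate(messages):
        if message.role == "system" and i != 0:
            raise ValueError(f"System message at index {i} must come first")
        if message.role == "assistant" and last_role != "user":
            raise ValueError(f"Assistant message at index {i} must follow a user message")
        last_role = message.role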
Couple of minor comments, otherwise looks good to go! This was a tough one, thanks for pushing this through. Let's make sure we keep track of the follow-up items that were discussed.
@@ -31,8 +31,11 @@
    "llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
}

# Inherit from tokenizer class to reuse its tokenize_messages method
did you just say inherit 👀
Best way to reduce copy-pasted code 😉. But for this case I do think it's the right thing, since it still allows us to call tokenize_messages (e.g. when testing on datasets) with the stripped-down tokenizer but using equivalent logic. The only usage currently is in test_slimorca.py, but I do kinda like that test, so keeping it like this for now.
Note: a lot of the changes to templates here are primarily done by @RdoubleA

Context

Changelog

- Add a tokenize_messages API to our tokenizer. This takes in a list of messages, tokenizes each one individually, then stitches the outputs together with any requisite special tokens, truncation, etc. For the SentencePiece tokenizer (which is all we currently support), this is just BOS and EOS, but for tokenizers with more complicated sets of special tokens, this allows us to customize at the tokenizer level without the dataset class having to worry about it.
- ChatDataset now calls tokenize_messages. This is more general than our current usage of tokenize_prompt_and_response.
- Chat formats now operate as List[Message] -> List[Message] to better align with the natural format of chat conversations. Instruct templates still operate as string formatters to maintain the ability for simple prompt formatting on instruct datasets.

Test plan

- Added new tokenizer tests for (a) the tokenize_messages API, and (b) encoding without leading whitespace.
- Refactored existing template tests into instruct template tests and chat format tests (again thanks @RdoubleA).
- LoRA finetune on the SlimOrca dataset (manually deleting dataset.use_clean from the config cause I can't override it from the CLI properly).
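For context, a hedged sketch of how the pieces above fit together on the dataset side (not the verbatim ChatDataset code; argument names are illustrative):

CROSS_ENTROPY_IGNORE_IDX = -100  # assumed value of the loss-ignore index

def prepare_sample(sample, tokenizer, convert_to_messages, chat_format, max_seq_len):
    # raw sample -> Messages -> chat-formatted Messages -> tokens + loss mask
    messages = convert_to_messages(sample)
    messages = chat_format.format(messages)
    tokens, mask = tokenizer.tokenize_messages(messages, max_seq_len=max_seq_len)
    labels = [
        CROSS_ENTROPY_IGNORE_IDX if masked else token
        for token, masked in zip(tokens, mask)
    ]
    return tokens, labels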