[RFC] Datasets and data utils in TorchTune #493
Conversation
Links:
- Load_dataset: https://huggingface.co/docs/datasets/v1.12.0/loading.html
- Cloud storage: https://huggingface.co/docs/datasets/en/filesystems
- Streaming: https://huggingface.co/docs/datasets/en/stream
https://github.com/mosaicml/streaming is also a high-quality streaming dataset library, and it handles dataset mixing ratios as well. The built-in streaming in HF's datasets library is not very robust, and efficiently processing IterableDatasets at runtime is a consideration that should be addressed. For example, I've seen that processing datasets during training leads to VRAM memory leaks.
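For context, a minimal sketch of what dataset mixing looks like with mosaicml-streaming, assuming its `Stream`/`StreamingDataset` API; the bucket paths and proportions are placeholders:

```python
# Sketch: mix two streaming datasets with fixed sampling proportions.
from streaming import Stream, StreamingDataset

streams = [
    Stream(remote="s3://my-bucket/dataset_a", proportion=0.7),  # ~70% of samples
    Stream(remote="s3://my-bucket/dataset_b", proportion=0.3),  # ~30% of samples
]
dataset = StreamingDataset(streams=streams, shuffle=True, batch_size=8)
```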
For now we've made streaming datasets a P1, or at least a quick follow up after the MVP launch. How important would you say it is to have support for both streaming and iterable datasets out of the gate?
> processing datasets during training leads to VRAM memory leaks.

We should definitely be wary of this, especially for sample packing and any heavy transforms.
> processing datasets during training leads to VRAM memory leaks.

Curious to know if this occurred for text-only datasets too, or only in multimodal datasets?
I haven't experimented much with multimodal datasets yet.
raw text corpus. The awesome part is that you can combine multiple datasets all
from the config (I think) and specify their template.
This is generally correct for the most common dataset structures. There is enough coverage for dataset structure -> chosen prompt format that it is possible to combine several datasets at once.
I am curious to learn whether Axolotl supports a separate prompt format for each dataset? Also how are the datasets combined - are the samples in a batch interleaved or is it from one dataset?
Each dataset can have its own prompt format, and after tokenization, we concatenate and shuffle the tokenized datasets.
## TuneDataset
The general flow of loading a dataset from data file to tokenized prompt and label consists of:
- (optional) Packing the dataset offline and caching it / saving it for use in current and future runs
Offline packing leads to a fixed "shuffle" for a given batch between epochs. While this is partially mitigated with larger batch sizes, I've found most users have asked if we could shuffle the packing between epochs too.
@winglian Thanks for the comment! Very interesting.

> while this is partially mitigated with larger batch sizes

I didn't quite understand how this is mitigated with larger batch sizes - would the samples still be in the same order within that packed training sample?
Options here would be:
- The dataset could potentially perform that shuffle, given the token separation between packed samples
- The other option is to do on-the-fly sample packing, but the complication will be managing an uneven number of batches for each rank.
Interesting point - could we shuffle the packed samples on the fly using the integer mask at train time?
You can potentially pack samples into a fixed batch, and then increasing gradient accumulation steps would mix the various fixed batches into a single step. I think complete offline packing probably isn't necessary, as I believe it's pretty efficient to pack during training. The biggest gain from offline processing is in tokenization anyway.
rfc_data.md (Outdated)
encoded_prompt = self._tokenizer.encode(prompt)
labels = self._tokenizer.encode(sample["output"])
# Collate
collated_prompt, collated_labels = self._collator(encoded_prompt, labels)
Is there a reason you chose to use the collator within the dataset? Should this not be the responsibility of the DataLoader since the collate_fn is part of the DataLoader's signature? https://pytorch.org/docs/stable/data.html
Similar to @winglian's question: what does the collator do here, since it gets either a single sample or a packed sample - that is, an individual sample in both scenarios?
The primary reason is so all things data-related can be instantiated / configured together from the config:

```yaml
dataset:
  _component_: torchtune.datasets.build_dataset
  template: my_template
  collator: my_collator
  ...
```

If we allowed users to configure the collate_fn in the DataLoader directly, it would just be separate in the config:

```yaml
dataset:
  _component_: torchtune.datasets.build_dataset
  template: my_template
collator:
  _component_: torchtune.datasets.collators.my_collator
```

which is a little awkward. Although coupling the collator with the dataset and not using collate_fn may bring some confusion.
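For comparison, a minimal sketch of the alternative wiring, where collation stays on the DataLoader (standard PyTorch); here `dataset` and `padded_collate` are stand-ins for whatever the config resolves, not a fixed torchtune API:

```python
# Sketch: standard PyTorch wiring, with the collator passed to the DataLoader
# rather than held by the dataset. `dataset` is any map-style dataset of
# tokenized samples; `padded_collate` is whichever collator is configured.
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=padded_collate,  # batches and pads samples after __getitem__
)
```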
Would there be a DataLoader component too, or is it just the dataset? If there is a DataLoader component, collate_fn should be part of it, as it is an argument of the OSS DataLoader.
Datasets we should support are:
- HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
- Local files, specified with a simple path string
- Remote files via HTTPS, specified with a URL
I'm not sure how you feel about this library, https://pypi.org/project/smart-open/, but it provides a clean way to handle various formats, especially those commonly seen in more enterprise environments, such as s3://, gs://, and hdfs://.
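For illustration, a small sketch of reading a remote JSONL file with smart-open; the bucket and key are placeholders, and s3 access assumes credentials are already configured:

```python
# Sketch: smart-open exposes a drop-in open() that handles s3://, gs://, hdfs://, etc.
import json
from smart_open import open as sopen

with sopen("s3://my-bucket/data/train.jsonl", "r") as f:  # placeholder path
    samples = [json.loads(line) for line in f]
```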
two approaches to sample packing:
* Offline bin-packing: take a dataset and iterate through the sample lengths. Use
one of the algorithms to pack the dataset and either cache it or upload to HF hub.
Axolotl uses the first-fit-decreasing algorithm.
There is also a "hybrid" approach that Axolotl employs for large pretraining datasets, whereby it uses IterableDatasets. In this case, it sets a large buffer window and packs all of the samples in that window. You still get pretty high packing efficiency, and the time spent packing is relatively fast compared to finetuning. One optimization I've not tried is to ensure that the packing is handled async on CPU while training is happening. This would add negligible overhead to online packing.
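A rough sketch of the buffer-window idea being described, not Axolotl's actual implementation; the function names, buffer size, and first-fit policy are all illustrative:

```python
# Sketch: buffer a window of tokenized samples from an iterable source,
# pack the window greedily (first-fit decreasing), and yield packed sequences.
def buffered_pack(token_iter, max_seq_len, buffer_size=10_000):
    buffer = []
    for tokens in token_iter:
        buffer.append(tokens)
        if len(buffer) >= buffer_size:
            yield from _pack_window(buffer, max_seq_len)
            buffer = []
    if buffer:
        yield from _pack_window(buffer, max_seq_len)

def _pack_window(samples, max_seq_len):
    # First-fit decreasing over the window: sort longest-first, then place
    # each sample into the first pack that still has room.
    packs = []
    for s in sorted(samples, key=len, reverse=True):
        for pack in packs:
            if len(pack) + len(s) <= max_seq_len:
                pack.extend(s)
                break
        else:
            packs.append(list(s))
    return packs
```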
> One optimization I've not tried is to ensure that the packing is handled async on CPU while training is happening.

@winglian Can this be achieved by just setting num_workers=1 (or more) in the OSS DataLoader? (Setting it to more than 1 requires configuring sharding correctly for iterable datasets.)
I haven't dug into how exactly it gets handled with multiple workers in that case.
samples are uncorrelated. A more prudent approach would be to use an integer mask
that labels which sample is which to prevent cross-attending. This is something
For simplicity, Axolotl uses integer masks now (https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/utils/collators.py#L171-L173), as this allows Axolotl to simplify the computation of cu_seqlens across various model architectures. However, you can also pass position_ids and achieve the same result by parsing position_ids: https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py#L99-L155
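To make the integer-mask approach concrete, a small sketch of deriving cu_seqlens (cumulative sequence lengths, as consumed by varlen attention kernels) from per-token sample ids; the tensor values are a toy example:

```python
# Sketch: integer mask tags each token with its sample id within the pack.
import torch

sample_ids = torch.tensor([1, 1, 1, 2, 2, 2, 2, 3, 3])  # 3 packed samples
seq_lens = torch.bincount(sample_ids)[1:]                # [3, 4, 2]
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0))  # [0, 3, 7, 9]
```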
rfc_data.md (Outdated)
dataset:
  _component_: torchtune.datasets.build_dataset
  source: tatsu-lab/alpaca
  tokenizer: # TBD how to handle tokenizers
Is the tokenizer specific to the dataset? I don't feel there is a use case for a given model where we would need to specify different tokenizers for each dataset.
we make builders general enough that users can easily specify multiple datasets
in a config without using nested components?
- Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
- Should we perform sample packing offline or do an online ad-hoc approach?
Even for relatively large datasets, offline sample packing is pretty optimized when using numpy.
For offline sample packing, are we thinking of a separate script? If so, as discussed above, there could be preprocessing associated with it, so the configuration of that script needs to be figured out (maybe the dataset part of the config is for the offline script). But this would complicate things for folks who don't want packing.
If we don't do an offline script but do the packing before training starts, would it be done in dataloader workers or before the training loop starts? Dataloader workers can provide compute speedup through parallelism (if we want to cut down on twfb), but it would require a switch to an iterable dataset.
@winglian Are you referring here to full-dataset offline preprocessing, or to the iterable-style approach? If it is the former, curious to know how much time it took for, say, the alpaca dataset to complete offline sample packing.
Sorry, I flip-flopped my words and meant to say online sample packing is optimized when using numpy. I ran a quick test on the tatsu-lab/alpaca dataset and it took 6 seconds to pack everything (obviously this will vary based on the various packing configurations like total length, etc.).
rfc_data.md (Outdated)
# Instead, specify a builder
dataset:
  _component_: torchtune.datasets.build_dataset
Just to clarify: the intention here is that build_dataset will know which components to combine based on the various parameters defined below?
# Get a row
sample = self._data[index]
# Apply any transforms
sample = self._transform(sample)
Would the transform operate on both an individual sample and a packed sample? Same question, I guess, for template formatting.
Hmm, that's a great point, I hadn't thought about how packed samples would be templated and transformed... I suppose the packing might need to be done after all the preprocessing?
Yes, I think packing should come after all preprocessing.
- HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
- Local files, specified with a simple path string
- Remote files via HTTPS, specified with a URL
All should be readily supported by using HF’s load_dataset() API, including support
Is the suggestion that the user uses HF's load_dataset to set up data loading from all the above sources (including s3, gs, azure, etc.)?
- Remote files via HTTPS, specified with a URL
All should be readily supported by using HF’s load_dataset() API, including support
for JSON, JSONL, text, arrow, parquet, CSV, etc. We can provide a convenient utility
get_data or similar that handles loading the data from the right location given
Curious to learn what the utility will do on top of what load_dataset does - is it calling load_dataset vs. load_from_disk based on the path string?
HF's load_dataset is really only for remote datasets, without any handling for local/on-disk data. In Axolotl we handle this by first checking whether the dataset path is a huggingface dataset (without actually pulling down the data) and use that as part of the flow control later.
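A hedged sketch of that "check without downloading" flow control, using huggingface_hub's metadata API; error handling beyond the missing-repo case is elided, and the helper name is illustrative:

```python
# Sketch: decide whether a path refers to a Hub dataset without pulling the data.
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

def is_hub_dataset(path: str) -> bool:
    try:
        HfApi().dataset_info(path)  # metadata-only call, no download
        return True
    except RepositoryNotFoundError:
        return False
```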
method that does this as the transform here. Another example is llama recipes’ grammar
dataset which requires splitting and processing the strings directly before templating.

### Template
I think the chat template provided by HF is based on that particular model/tokenizer (https://huggingface.co/docs/transformers/main/en/chat_templating) - the tokenizer.apply_chat_template method.
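For reference, a minimal sketch of that per-tokenizer templating; the model id is just an example (and some models are gated):

```python
# Sketch: the tokenizer's own chat template decides the final prompt string.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # example model
messages = [{"role": "user", "content": "Summarize: the quick brown fox..."}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
```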
However, packing results in faster training since the model can see more samples
at a time, but just requires some additional processing of the dataset. There are
two approaches to sample packing:
* Offline bin-packing: take a dataset and iterate through the sample lengths. Use
Since llama recipes is mentioned below, I am curious how we plan to handle scenarios where adding another sample to the pack would exceed the max seq length - will we use that sample in the next packed sample, or will that sample be truncated and just packed to max seq length? llama-recipes takes the latter approach.
rfc_data.md (Outdated)
def samsum_dataset(tokenizer: Tokenizer) -> TuneDataset:
    return TuneDataset(
        source="samsum",
        template=SummarizeTemplate(),
Is template taking an object? I also see a string used below.
in a config without using nested components?
- Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
- Should we perform sample packing offline or do an online ad-hoc approach?
- Are there other abstractions we need to consider, or abstractions we should simplify? (i.e., just using str instead of a PromptTemplate class)
We certainly need a PromptTemplate class, I think (even for the alpaca example, as you had called out, since it has two prompts). Can a plain string also be supported for simple use cases?
Yes, we should support both.
in a config if nested configs / recursive instantiation are not enabled? How can
we make builders general enough that users can easily specify multiple datasets
in a config without using nested components?
- Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
Looks like there could be another "collate fn" for the attention mask? For this question specifically, any pointers to how the dataloader config looks today in torchtune?
Currently not configured at all - we just use padded_collate as the default. We'll have to show in our example configs how this can be changed if needed.
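For context, a minimal sketch of what a padded_collate-style default does; the field names and signature are illustrative, not torchtune's exact API:

```python
# Sketch: pad each field in the batch to the longest sequence.
import torch
from torch.nn.utils.rnn import pad_sequence

def padded_collate(batch, padding_idx=0, ignore_idx=-100):
    input_ids = pad_sequence(
        [torch.tensor(s["tokens"]) for s in batch],
        batch_first=True, padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(s["labels"]) for s in batch],
        batch_first=True, padding_value=ignore_idx,  # ignored by the loss
    )
    return input_ids, labels
```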
)
```

### Open questions
I) One feature that I recall @laurencer had wished for in this area: how do we verify that the transform + prompt + tokenizer are doing the right thing? Any misconfiguration here would result in subtle bugs that impact model performance. Metrics could be part of the answer - is sample packing helping, the distribution of sample token sizes, is padding bloating up the input/label sizes, etc.
II) Also, error message/propagation - if one sample doesn't have the expected data/format (maybe the transform fails), can that particular sample be identified (sample number printed) so that the user can debug it?
Datasets we should support are:
- HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
- Local files, specified with a simple path string
- Remote files via HTTPS, specified with a URL
Are we considering this a P0? If so, why?
This should be supported out of the box by load_dataset: https://huggingface.co/learn/nlp-course/en/chapter5/2#loading-a-remote-dataset
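A one-line sketch of that out-of-the-box support; the URL is a placeholder:

```python
# Sketch: load_dataset can read remote files directly by URL.
from datasets import load_dataset

ds = load_dataset("json", data_files="https://example.com/train.jsonl", split="train")
```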
Providing a great user experience for using common datasets or plugging in custom datasets is critical to the success of TorchTune, or any fine-tuning library. Here, I propose an API design and walk through the different levels of abstractions we'll need. There are still open questions around how this should interplay with the config system that I would love feedback on.
The RFC is included in the changed files as a markdown file; I'm pasting it here for an easy read, but you can comment on a specific line in the file directly.
# Datasets in TorchTune
## Motivation
When fine-tuning LLMs, there are three main areas where users can influence the
end result and final performance: the model architecture, the fine-tuning hyperparameters,
and the dataset. For open source users, many of whom are hobbyists and hackers, modifying
the architecture itself or experimenting with hyperparameters is infeasible, due either
to lack of expertise or lack of resources. Most often they will just use the
best pre-trained model for their use case, given the plethora of options, with some
default hyperparameters. The most common user journey for fine-tuning an LLM is
to quickly bootstrap a pre-trained model with their custom dataset(s). This means
that data is OSS users’ primary means of controlling the model. It then becomes
imperative that we curate the smoothest user experience, with world-class API design,
for plugging in custom datasets to fine-tune with TorchTune. This document will
overview the abstractions needed to support custom datasets in TorchTune, a high-level
API design, and what the user journey with these components looks like in
TorchTune.
“However, the real challenge lies in preparing the data. A massive wiki of product
documentation, a thousand PDFs of your processes, or even a bustling support forum
with countless topics - they all amount to nothing if you don't have your data in
the right format. Projects like Dolly and Orca have shown us how enriching data
with context or system prompts can significantly improve the final model's quality
[...] Personally, I mostly utilize the #instruction, #input, #output format for
most of my fine-tuning tasks. So, shaping your data in the correct format is, without
a doubt, the most difficult and time-consuming step when creating a Language Learning
Model (LLM) for your company's documentation, processes, support, sales, and so
forth.” - user on r/LocalLLaMA
## Existing OSS Solutions
### HuggingFace
Links:
- Load_dataset: https://huggingface.co/docs/datasets/v1.12.0/loading.html
- Cloud storage: https://huggingface.co/docs/datasets/en/filesystems
- Streaming: https://huggingface.co/docs/datasets/en/stream
HuggingFace does a great job providing an incredible breadth of utilities that allow
users to load in any local file, remote HuggingFace dataset, or dataset on cloud
storage, and provides basic preprocessing functionality (including shard, map, multiprocess,
interleave, concatenate, filter, split, etc.). These datasets can also be streamed
so that not everything is downloaded at once. If a user wants to use a dataset present on
the HuggingFace hub, they should be able to leverage the full functionality of this
ecosystem.
### Axolotl
Link: https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset
One aspect Axolotl does really well is maintaining a suite of various prompt templates
that will automatically tokenize a data source into the template, with some configurability
from the YAML config. This covers everything from instruction to conversations to
raw text corpus. The awesome part is that you can combine multiple datasets all
from the config (I think) and specify their template.
## TuneDataset
The general flow of loading a dataset from data file to tokenized prompt and label consists of:
- Loading the data from its source
- (optional) Applying a transform to preprocess the raw sample
- Formatting the sample with a prompt template
- Tokenizing the prompt and label
- Collating the tokenized sample
- (optional) Packing the dataset offline and caching it / saving it for use in current and future runs
Since each step uses a user-provided component, every step is fully customizable
using standard components in the library or custom components from the user, provided
that they follow a particular interface.
### Source
Datasets we should support are:
- HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
- Local files, specified with a simple path string
- Remote files via HTTPS, specified with a URL
All should be readily supported by using HF’s load_dataset() API, including support
for JSON, JSONL, text, arrow, parquet, CSV, etc. We can provide a convenient utility
get_data or similar that handles loading the data from the right location given
a simple path string.
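A hypothetical sketch of such a get_data utility; the routing heuristics and helper name are illustrative only, not a committed design:

```python
# Sketch: route a single path string to the right load_dataset call.
import os
from datasets import load_dataset

def get_data(source: str, **kwargs):
    if source.startswith(("http://", "https://")) or os.path.exists(source):
        # Local files and URLs: infer the loader from the file extension.
        ext = source.split("?")[0].rsplit(".", 1)[-1]
        fmt = {"jsonl": "json", "txt": "text"}.get(ext, ext)
        return load_dataset(fmt, data_files=source, **kwargs)
    # Otherwise assume a Hub id like "tatsu-lab/alpaca".
    return load_dataset(source, **kwargs)
```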
### Transform
This is an optional parameter that users can provide to do any preprocessing on
their data before templating. The most immediate use case would be to convert from
one prompt template to another. For example, in TorchTune’s SlimOrcaDataset implementation,
we convert from ShareGPT Conversation template to llama2 chat. We can pass in a
method that does this as the transform here. Another example is llama recipes’ grammar
dataset which requires splitting and processing the strings directly before templating.
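As an illustration, a transform along the lines of the SlimOrca example, converting a ShareGPT-style record into flat fields before templating; the field names are assumptions about the raw data, not a fixed schema:

```python
# Sketch: flatten a ShareGPT-style conversation into prompt/response fields.
def sharegpt_to_prompt_response(sample: dict) -> dict:
    turns = sample["conversations"]  # [{"from": "human"/"gpt", "value": ...}, ...]
    prompt = "\n".join(t["value"] for t in turns if t["from"] == "human")
    response = "\n".join(t["value"] for t in turns if t["from"] == "gpt")
    return {"instruction": prompt, "output": response}
```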
### Template
This class should handle massaging the input data columns/fields into a predefined
prompt template. We will support the most common ones: instruct, chat, sharegpt,
etc. This is akin to Axolotl’s Prompters.
One challenge is mapping the input data columns into the correct fields for the
prompt template, given that the input dataset could have different columns, formats,
etc. We can allow users to provide the column to field mapping for the prompt template.
We could also just make the prompt templates plain strings and use the native string
format method. However, for cases like Alpaca where we want to handle with / without
input, we need a bit more functionality. Using a class is also more extensible.
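A sketch of what such a class could look like for the Alpaca with/without-input case, including a user-provided column mapping; the class and method names are illustrative:

```python
# Illustrative PromptTemplate-style class: two prompt variants plus an optional
# column-to-field mapping for datasets with differently named columns.
class AlpacaTemplate:
    with_input = (
        "Below is an instruction that describes a task, paired with an input that "
        "provides further context. Write a response that appropriately completes "
        "the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )
    without_input = (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
    )

    def format(self, sample: dict, column_map: dict = None) -> str:
        column_map = column_map or {}
        instruction = sample[column_map.get("instruction", "instruction")]
        inp = sample.get(column_map.get("input", "input"), "")
        if inp:
            return self.with_input.format(instruction=instruction, input=inp)
        return self.without_input.format(instruction=instruction)
```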
### Collator
While transforms process the sample BEFORE templating and tokenization, collators
include any data utils that process the sample AFTER tokenization. The primary util
is padding to a max length, which we can repurpose utils.padded_collate for. We
can also make the train_on_input functionality a collator.
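A sketch of the train_on_input idea as a post-tokenization step; names are illustrative, and -100 is PyTorch's conventional ignore index for cross-entropy:

```python
# Sketch: when train_on_input is False, mask the prompt span out of the labels
# so the loss is computed on the response tokens only.
def build_labels(prompt_tokens, response_tokens, train_on_input=False, ignore_idx=-100):
    input_ids = prompt_tokens + response_tokens
    prompt_part = prompt_tokens if train_on_input else [ignore_idx] * len(prompt_tokens)
    labels = prompt_part + response_tokens
    return input_ids, labels
```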
Open question: should we ditch using the collate_fn kwarg in DataLoader in favor
of coupling the collator with TuneDataset? What’s the tradeoff here?
### Sample packing
Packing involves stuffing multiple data samples in the input up to the max sequence
length to make full use of the context window. The algorithms to achieve this are
the same ones used to solve the classic bin packing problem. Unfortunately, these
algorithms require knowledge of the full distribution of sample lengths, meaning
we need to iterate through the entire dataset before we can begin sample packing.
However, packing results in faster training since the model can see more samples
at a time; it just requires some additional processing of the dataset. There are
two approaches to sample packing:
* Offline bin-packing: take a dataset and iterate through the sample lengths. Use
one of the algorithms to pack the dataset and either cache it or upload it to HF hub.
Axolotl uses the first-fit-decreasing algorithm. Caching the packed dataset takes
significant storage space, but overall fine-tuning is faster.
* Online greedy packing: we don’t see all the sample lengths at once, so we cannot use
a bin-packing algorithm; instead, we greedily pack the context window as we iterate.
This is done by llama recipes (a rough sketch follows below).

The approach we take could dictate the design of the sample packing API. Options are:
a separate offline script, a boolean flag in the TuneDataset class, or an entirely
different dataset class.
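A rough sketch of the online greedy approach, in the spirit of llama recipes' concatenate-and-chunk behavior; the function name and generator structure are illustrative:

```python
# Sketch: greedily concatenate tokenized samples and emit fixed-size chunks.
def greedy_pack(token_iter, max_seq_len):
    buffer = []
    for tokens in token_iter:
        buffer.extend(tokens)
        while len(buffer) >= max_seq_len:
            yield buffer[:max_seq_len]
            buffer = buffer[max_seq_len:]
```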
On masking for packed samples: Because we have multiple samples in the same input,
we need to tell the model to not attend to other irrelevant samples. Generally,
the guidance in OSS has been that the EOS token between samples is sufficient because
samples are uncorrelated. A more prudent approach would be to use an integer mask
that labels which sample is which to prevent cross-attending. This is something
we should support, and may involve a custom collator to handle this. More discussion
can be found here: meta-llama/llama-recipes#341
## Configuring datasets in configs
As proposed, TuneDataset requires other higher-level components as parameters, such
as PromptTemplate, Tokenizer, and Callables, each with their own keyword arguments.
This is problematic because we purposely restricted nested components in the config
system to avoid meta-programming via yaml file. In other words, you cannot configure
a dataset directly with TuneDataset from the config. There is an alternative approach
that still enables configurable datasets without compromising on the nested-components
principle.
### Builders
Common datasets will have builder functions with flat params that can be easily
specified with one component in the config file. We can also provide a builder function
for custom datasets with limited functionality. This may require some registry object
that contains the mapping from string to common prompt templates or collators and
an associated getter function. While this is not ideal and something that the current
config.instantiate API was originally trying to bypass, we can keep it fairly contained
to just prompt templates and basic collators/transforms, for example.
Options that require more involved customization, such as custom transforms, will
require a user to create custom builder functions that they can then specify via
the config. Since transforms will typically require the user to write code anyway,
adding the responsibility of creating a builder is not too burdensome. I contend
that this extra user burden is worth intentionally restricting nested components
to prevent intrusive configs.
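A sketch of what such a builder could look like under this proposal; TuneDataset and get_template stand in for the proposed class and the registry getter described above, so this mirrors the RFC's own snippet style rather than a finished API:

```python
# Sketch: a flat-parameter builder that a config can target as one component.
def build_dataset(
    source: str,
    tokenizer,
    template: str = "alpaca",
    train_on_input: bool = True,
    max_seq_len: int = 2048,
):
    return TuneDataset(
        source=source,
        template=get_template(template),  # registry lookup: str -> PromptTemplate
        tokenizer=tokenizer,
        train_on_input=train_on_input,
        max_seq_len=max_seq_len,
    )
```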
## Supported Datasets
Here, we use our flagship supported datasets as examples to exhibit the versatility
and extensibility of the TuneDataset API design.
### Alpaca
Instruct tasks. A good example of a straightforward dataset that doesn’t require transforms,
column mapping, or template conversion.
### SlimOrca
Conversational/chat tasks. A good example of utilizing the transform for template
conversion.
### Samsum
Summarization tasks.
### Grammar
Primarily for grammatical error correction tasks. In llama recipes, two datasets
are used. For now, we will use one dataset until the multi-dataset UX is more thought
out.
## Open questions
- How do we configure datasets in a config if nested configs / recursive instantiation are not enabled? How can we make builders general enough that users can easily specify multiple datasets in a config without using nested components?
- Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
- Should we perform sample packing offline or do an online ad-hoc approach?
- Are there other abstractions we need to consider, or abstractions we should simplify? (i.e., just using str instead of a PromptTemplate class)