[RFC] Datasets and data utils in TorchTune #493

Closed
Conversation

RdoubleA (Contributor)

Providing a great user experience for using common datasets or plugging in custom datasets is critical to the success of TorchTune, or any fine-tuning library. Here, I propose an API design and walk through the different levels of abstractions we'll need. There are still open questions around how this should interplay with the config system that I would love feedback on.

The RFC is included in the changed files as a markdown file. I'm pasting it here for an easy read, but you can comment on a specific line in the file directly.

Datasets in TorchTune

Motivation

When fine-tuning LLMs, there are three main areas where users can influence the
end result and final performance: the model architecture, fine-tuning hyperparameters,
and the dataset. For open source users, many of whom are hobbyists and hackers, modifying
the architecture itself or experimenting with hyperparameters is often infeasible, either
due to lack of expertise or lack of resources. Most often they will just use the
best pre-trained model for their use case, given the plethora of options, with some
default hyperparameters. The most common user journey for fine-tuning an LLM is
to quickly bootstrap a pre-trained model with their custom dataset(s). This means
that data is OSS users’ primary means of controlling the model. It then becomes
imperative that we curate the smoothest user experience with world-class API design
for plugging in custom datasets to fine-tune with TorchTune. This document will
overview the abstractions needed to support custom datasets in TorchTune, a high
level API design, and what the user journey with these components looks like in
TorchTune.

“However, the real challenge lies in preparing the data. A massive wiki of product
documentation, a thousand PDFs of your processes, or even a bustling support forum
with countless topics - they all amount to nothing if you don't have your data in
the right format. Projects like Dolly and Orca have shown us how enriching data
with context or system prompts can significantly improve the final model's quality
[...] Personally, I mostly utilize the #instruction, #input, #output format for
most of my fine-tuning tasks. So, shaping your data in the correct format is, without
a doubt, the most difficult and time-consuming step when creating a Language Learning
Model (LLM) for your company's documentation, processes, support, sales, and so
forth.” - user on r/LocalLLaMA

Existing OSS Solutions

HuggingFace

Links:
- load_dataset: https://huggingface.co/docs/datasets/v1.12.0/loading.html
- Cloud storage: https://huggingface.co/docs/datasets/en/filesystems
- Streaming: https://huggingface.co/docs/datasets/en/stream

HuggingFace does a great job of providing an incredible breadth of utilities that allow
users to load any local file, remote HuggingFace dataset, or dataset on cloud
storage, and it provides basic preprocessing functionality (including shard, map, multiprocess,
interleave, concatenate, filter, split, etc.). These datasets can also be streamed
so the data is not all downloaded at once. If a user wants to use a dataset present on
the HuggingFace hub, they should be able to leverage the full functionality of this
ecosystem.

Axolotl

Link: https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset

One aspect Axolotl does really well is maintaining a suite of various prompt templates
that will automatically tokenize a data source into the template, with some configurability
from the YAML config. This covers everything from instruction to conversations to
raw text corpus. The awesome part is that you can combine multiple datasets all
from the config (I think) and specify their template.

TuneDataset

The general flow of loading a dataset from data file to tokenized prompt and label consists of:

  • (optional) Packing the dataset offline and caching it / saving it for use in current and future runs
  • Get a single sample
  • Apply user transform to any of the columns - this could also be converting from one template to another, as is the case for SlimOrca: sharegpt -> llama2 chat
  • Format into provided template using PromptTemplate’s methods
  • Tokenize with provided tokenizer
  • Collate tokenized output - padding, modify masks for packing, etc

Since each step uses a user-provided component, every step is fully customizable
using standard components in the library or custom components from the user, provided
that it follows a particular interface.
from typing import Callable, Dict, List, Optional, Tuple, Union

from torch.utils.data import Dataset

class TuneDataset(Dataset):
    def __init__(
        self,
        source: str,
        column_map: Optional[Dict[str, str]],
        transform: Optional[Callable],
        template: Union[PromptTemplate, str],
        tokenizer: Tokenizer,
        collator: Callable,
        packing: bool = False,
    ) -> None:
        # Set all attributes
        ...
        self._data = get_data(source)
        if packing:
            self._data = sample_packing(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        # Get a row
        sample = self._data[index]
        # Apply any transforms, if provided
        if self._transform is not None:
            sample = self._transform(sample)
        # Format into template
        prompt = self._template.format(sample, self._column_map)
        # Tokenize
        encoded_prompt = self._tokenizer.encode(prompt)
        labels = self._tokenizer.encode(sample["output"])
        # Collate
        collated_prompt, collated_labels = self._collator(encoded_prompt, labels)

        return collated_prompt, collated_labels

Source

Datasets we should support are:

  • HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
  • Local files, specified with a simple path string
  • Remote files via HTTPS, specified with a URL

All should be readily supported by using HF’s load_dataset() API, including support
for JSON, JSONL, text, arrow, parquet, CSV, etc. We can provide a convenient utility
get_data or similar that handles loading the data from the right location given
a simple path string; a sketch follows.
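As a rough illustration, such a helper could dispatch on the path string and delegate everything to load_dataset. The extension-to-builder mapping and split handling below are assumptions, not a final API:

```python
# Sketch only: dispatch a path string to HF's load_dataset.
import os
from datasets import load_dataset

# Map raw-file extensions to HF packaged builders (illustrative subset).
_EXT_TO_BUILDER = {".json": "json", ".jsonl": "json", ".csv": "csv",
                   ".txt": "text", ".parquet": "parquet"}

def get_data(source: str, split: str = "train"):
    ext = os.path.splitext(source)[1].lower()
    if ext in _EXT_TO_BUILDER:
        # Local path or HTTPS URL pointing at a raw data file
        return load_dataset(_EXT_TO_BUILDER[ext], data_files=source, split=split)
    # Otherwise assume a HuggingFace Hub path, e.g. "tatsu-lab/alpaca"
    return load_dataset(source, split=split)
```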

Transform

This is an optional parameter that users can provide to do any preprocessing on
their data before templating. The most immediate use case would be to convert from
one prompt template to another. For example, in TorchTune’s SlimOrcaDataset implementation,
we convert from ShareGPT Conversation template to llama2 chat. We can pass in a
method that does this as the transform here. Another example is llama recipes’ grammar
dataset which requires splitting and processing the strings directly before templating.
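For concreteness, a transform is just a callable on a single sample dict. A hypothetical ShareGPT-to-chat conversion might look like the sketch below; the column names and output format are illustrative, not TorchTune's actual implementation:

```python
from typing import Any, Dict

def convert_sharegpt_to_chat(sample: Dict[str, Any]) -> Dict[str, str]:
    """Flatten a ShareGPT-style `conversations` list into input/output strings."""
    role_map = {"system": "system", "human": "user", "gpt": "assistant"}
    turns = sample["conversations"]
    # Everything before the final assistant turn becomes the prompt;
    # the final assistant message becomes the target.
    prompt = "\n".join(f"[{role_map[t['from']]}]: {t['value']}" for t in turns[:-1])
    return {"input": prompt, "output": turns[-1]["value"]}
```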

Template

This class should handle massaging the input data columns/fields into a predefined
prompt template. We will support the most common ones: instruct, chat, sharegpt,
etc. This is akin to Axolotl’s Prompters.

class PromptTemplate:
    system = "This is an example system prompt: {input} {output}"
    def format(self, sample: Dict[str, str], column_map: Dict[str, str]) -> str:
        return self.system.format(input=sample[column_map["input"]], output=sample[column_map["output"]])

One challenge is mapping the input data columns into the correct fields for the
prompt template, given that the input dataset could have different columns, formats,
etc. We can allow users to provide the column to field mapping for the prompt template.

We could also just make the prompt templates plain strings and use the native string
format method. However, for cases like Alpaca where we want to handle with / without
input, we need a bit more functionality. Using a class is also more extensible.
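As a sketch of how the class-based approach can handle the with/without input case (the prompt strings follow the common Alpaca format but are illustrative here):

```python
from typing import Dict, Optional

class AlpacaInstructTemplate:
    template_with_input = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n### Response:\n"
    )
    template_no_input = (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    )

    def format(self, sample: Dict[str, str], column_map: Optional[Dict[str, str]] = None) -> str:
        column_map = column_map or {}
        instruction = sample[column_map.get("instruction", "instruction")]
        input_text = sample.get(column_map.get("input", "input"), "")
        # Fall back to the no-input template when the input field is empty or missing.
        if input_text:
            return self.template_with_input.format(instruction=instruction, input=input_text)
        return self.template_no_input.format(instruction=instruction)
```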

Collator

While transforms process the sample BEFORE templating and tokenization, collators
include any data utils that process the sample AFTER tokenization. The primary util
is padding to a max length, which we can repurpose utils.padded_collate for. We
can also make the train_on_input functionality a collator.
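A minimal per-sample padding collator consistent with the TuneDataset pseudocode above might look like this; the pad/ignore values are assumptions (e.g., -100 as the standard CrossEntropyLoss ignore index):

```python
from typing import List, Tuple

def pad_to_max_len(
    tokens: List[int],
    labels: List[int],
    max_seq_len: int,
    pad_id: int = 0,
    ignore_idx: int = -100,
) -> Tuple[List[int], List[int]]:
    # Truncate to the window, then right-pad tokens with pad_id and labels with
    # the ignore index so padded positions don't contribute to the loss.
    tokens = tokens[:max_seq_len] + [pad_id] * max(0, max_seq_len - len(tokens))
    labels = labels[:max_seq_len] + [ignore_idx] * max(0, max_seq_len - len(labels))
    return tokens, labels
```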

Open question: should we ditch using the collate_fn kwarg in DataLoader in favor
of coupling the collator with TuneDataset? What’s the tradeoff here?

Sample packing

Packing involves stuffing multiple data samples into the input up to the max sequence
length to make full use of the context window. The algorithms to achieve this are
the same ones used to solve the classic bin packing problem. Unfortunately, these
algorithms require knowledge of the full distribution of sample lengths, meaning
we need to iterate through the entire dataset before we can begin sample packing.
Packing results in faster training since the model sees more samples at a time,
but it requires some additional processing of the dataset. There are
two approaches to sample packing:

  • Offline bin-packing: take a dataset and iterate through the sample lengths. Use
    one of the algorithms to pack the dataset and either cache it or upload to HF hub.
    Axolotl uses the first-fit-decreasing algorithm.
    • Tradeoff: TTFB is much longer, need to figure out caching which may require
      significant storage space. But overall fine-tuning is faster.
  • Online greedy: instead of packing offline, do it as we iterate through the dataset.
    We don’t see all the sample lengths at once so we cannot use a bin-packing algorithm;
    instead we greedily pack the context window as we iterate (see the sketch after this
    list). This is done by llama recipes.
    • Tradeoff: Faster TTFB, no caching of entire dataset. Slower fine-tuning.
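A sketch of the online greedy approach, assuming an EOS token separates packed samples and samples longer than the window are truncated; this is illustrative, not the proposed API:

```python
from typing import Iterable, Iterator, List

def greedy_pack(
    tokenized_samples: Iterable[List[int]],
    max_seq_len: int,
    eos_id: int,
) -> Iterator[List[int]]:
    pack: List[int] = []
    for tokens in tokenized_samples:
        tokens = tokens[: max_seq_len - 1]  # leave room for the EOS separator
        if pack and len(pack) + len(tokens) + 1 > max_seq_len:
            yield pack  # current window is full, emit it and start a new one
            pack = []
        pack.extend(tokens + [eos_id])
    if pack:
        yield pack
```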

The approach we take could dictate the design of the sample packing API. Options are:
a separate offline script, a boolean flag in the TuneDataset class, or an entirely
different dataset class.

On masking for packed samples: Because we have multiple samples in the same input,
we need to tell the model to not attend to other irrelevant samples. Generally,
the guidance in OSS has been that the EOS token between samples is sufficient because
samples are uncorrelated. A more prudent approach would be to use an integer mask
that labels which sample is which to prevent cross-attending. This is something
we should support, and may involve a custom collator to handle this. More discussion
can be found here: meta-llama/llama-recipes#341
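A sketch of the integer-mask idea: tag each packed sample with its own id so a collator can build a block-diagonal attention mask from it. The function names here are illustrative:

```python
from typing import List
import torch

def sample_id_mask(sample_lengths: List[int]) -> torch.Tensor:
    """E.g. lengths [3, 2] -> tensor([0, 0, 0, 1, 1])."""
    return torch.cat(
        [torch.full((length,), i, dtype=torch.long) for i, length in enumerate(sample_lengths)]
    )

def block_diagonal_mask(sample_ids: torch.Tensor) -> torch.Tensor:
    """True where query and key positions belong to the same packed sample."""
    return sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)
```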

Configuring datasets in configs

As proposed, TuneDataset requires other higher-level components as parameters, such
as PromptTemplate, Tokenizer, and Callables, each with their own keyword arguments.
This is problematic because we purposely restricted nested components in the config
system to avoid meta-programming via the yaml file. In other words, you cannot configure
a dataset directly with TuneDataset from the config. There is an alternative approach
that still enables configurable datasets without compromising on the nested components
principle.

Builders

Common datasets will have builder functions with flat params that can be easily
specified with one component in the config file. We can also provide a builder function
for custom datasets with limited functionality. This may require some registry object
that contains the mapping from string to common prompt templates or collators and
an associated getter function. While this is not ideal and something that the current
config.instantiate API was originally trying to bypass, we can keep it fairly contained
to just prompt templates and basic collators/transforms, for example.
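A sketch of what such a registry and getter could look like; the classes referenced are the templates sketched elsewhere in this RFC, and the registry keys are assumptions:

```python
# AlpacaInstructTemplate, Llama2ChatTemplate, SummarizeTemplate are the template
# classes described elsewhere in this RFC.
_TEMPLATES = {
    "instruct": AlpacaInstructTemplate,
    "chat": Llama2ChatTemplate,
    "summarize": SummarizeTemplate,
}

def get_template(name: str) -> "PromptTemplate":
    # Map a plain config string (e.g. "instruct") to a template instance.
    if name not in _TEMPLATES:
        raise ValueError(f"Unknown template '{name}', expected one of {list(_TEMPLATES)}")
    return _TEMPLATES[name]()
```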

def build_dataset(
    source: str,
    column_map: Dict[str, str],
    tokenizer: Tokenizer,  # do we need a mapping for all tokenizers or rely on partial instantiation?
    template: str,  # Choose from common templates in library
    pad: bool = True,
    packing: bool = False,
) -> TuneDataset:
    ...

# In the yaml config - we cannot do nested components
dataset:
  _component_: torchtune.datasets.TuneDataset
  template:
    _component_: torchtune.datasets.prompt_templates.AlpacaInstructTemplate
    ...

# Instead, specify a builder
dataset:
  _component_: torchtune.datasets.build_dataset
  source: tatsu-lab/alpaca
  tokenizer: # TBD how to handle tokenizers
  template: instruct
  pad: True
  packing: True

Options that require more involved customization, such as custom transforms, will
require a user to create custom builder functions that they can then specify via
the config. Since transforms will typically require the user to write code anyway,
adding the responsibility of creating a builder is not too burdensome. I contend
that this extra user burden is worth it to keep nested components restricted and
prevent intrusive configs.

def my_custom_dataset(
    my_param: str,
) -> TuneDataset:
    # Logic to create custom dataset here with exact components
    ...

dataset:
  _component_: torchtune.datasets.my_custom_dataset
  my_param: hello

Supported Datasets

Here, we use our flagship supported datasets as examples to exhibit the versatility
and extensibility of the TuneDataset API design.

Alpaca

Instruct tasks. A good example of a straightforward dataset that doesn’t require transforms,
column mapping, or template conversion.

def alpaca_dataset(tokenizer: Tokenizer, use_clean: bool = False) -> TuneDataset:
    return TuneDataset(
        source="yahma/alpaca-cleaned" if use_clean else "tatsu-lab/alpaca",
        template=AlpacaInstructTemplate(),
        tokenizer=tokenizer,
        collator=pad_and_train_on_input,
    )

SlimOrca

Conversational/chat tasks. Good example of utilizing the transform for template
conversion.

  • Source: Open-Orca/SlimOrca-Dedup
  • Source Template: sharegpt conversation
  • Target Template: In TorchTune, we convert to llama2 chat using _generate_prompt_label
  • Collator: pad
def slim_orca_dataset(tokenizer: Tokenizer) -> TuneDataset:
    return TuneDataset(
        source="Open-Orca/SlimOrca-Dedup",
        transform=convert_sharegpt_to_llama2_chat,
        template=Llama2ChatTemplate(),
        tokenizer=tokenizer,
        collator=pad,
    )

Samsum

Summarization tasks.

  • Source: samsum
  • Source Template: dialogue and summary columns
  • Target Template: Summary - similar to SummarizeTLDRPrompter in Axolotl. Also see the version in llama recipes
  • Collator: pad
def samsum_dataset(tokenizer: Tokenizer) -> TuneDataset:
    return TuneDataset(
        source="samsum",
        template=SummarizeTemplate(),
        tokenizer=tokenizer,
        collator=pad,
    )

Grammar

Primarily for grammatical error correction tasks. In llama recipes, two datasets
are used. For now, we will use one dataset until the multi-dataset UX is more thought
out.

def grammar_dataset(tokenizer: Tokenizer) -> TuneDataset:
    return TuneDataset(
        source="liweili/c4_200m",
        column_map={"sentence": "input", "correction": "output"},
        template="Correct this to standard English: {sentence}\n---\nCorrected: {correction}",
        tokenizer=tokenizer,
        collator=pad,
    )

Open questions

  • TuneDataset has high level components as parameters. How can users specify this
    in a config if nested configs / recursive instantiation are not enabled? How can
    we make builders general enough that users can easily specify multiple datasets
    in a config without using nested components?
  • Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
  • Should we perform sample packing offline or do an online ad-hoc approach?
  • Are there other abstractions we need to consider, or abstractions we should simplify? (e.g., just using str instead of a PromptTemplate class)

Links:
- Load_dataset: https://huggingface.co/docs/datasets/v1.12.0/loading.html
- Cloud storage: https://huggingface.co/docs/datasets/en/filesystems
- Streaming: https://huggingface.co/docs/datasets/en/stream
Contributor

https://github.com/mosaicml/streaming is also a high-quality streaming dataset library that also handles dataset mixing ratios. The built-in streaming in HF's datasets library is not very robust, and efficiently processing IterableDatasets at runtime is a consideration that should be addressed. For example, I've seen that processing datasets during training leads to VRAM memory leaks.

Contributor Author

For now we've made streaming datasets a P1, or at least a quick follow up after the MVP launch. How important would you say it is to have support for both streaming and iterable datasets out of the gate?

processing datasets during training leads to VRAM memory leaks.

We should definitely be wary of this especially for sample packing and any heavy transforms

Contributor

processing datasets during training leads to VRAM memory leaks.

Curious to know if this occured for text only datasets too or only in multimodal datasets?

Contributor

I haven't experimented much with multimodal datasets yet.

Comment on lines +54 to +55
raw text corpus. The awesome part is that you can combine multiple datasets all
from the config (I think) and specify their template.
Contributor

This is generally correct for the most common dataset structures. There is enough coverage for dataset structure -> chosen prompt format that it is possible to combine several datasets at once.

Contributor

I am curious to learn whether Axolotl supports a separate prompt format for each dataset? Also how are the datasets combined - are the samples in a batch interleaved or is it from one dataset?

Contributor

Each dataset can have its own prompt format, and after tokenization we concatenate and shuffle the tokenized datasets.


## TuneDataset
The general flow of loading a dataset from data file to tokenized prompt and label consists of:
- (optional) Packing the dataset offline and caching it / saving it for use in current and future runs
Contributor

Offline packing leads to fixed "shuffling" for a given batch between epochs. While this is partially mitigated with larger batch sizes, I've found most users have asked if we could shuffle the packing between epochs too.

Contributor @gokulavasan Mar 13, 2024

@winglian Thanks for the comment! Very interesting.

while this is partially mitigated with larger batch sizes

I didn't quite understand how this is mitigated with larger batch sizes - would the samples still be in the same order within that packed training sample?

Options here would be:

  1. Dataset could potentially perform that shuffle given the token separation between packed samples
  2. The other option is to do on the fly sample packing but the complication will be with managing uneven number of batches for each rank.

Contributor Author

interesting point, could we shuffle the packed samples on-the-fly using the integer mask during train time?

Contributor

You can potentially pack samples into a fixed batch, and then when increasing gradient accumulation steps, those would mix the various fixed batches into a single step. I think complete offline packing probably isn't necessary, as I believe it's pretty efficient to pack during training. The biggest gain from going offline is in tokenization anyway.

rfc_data.md Outdated
encoded_prompt = self._tokenizer.encode(prompt)
labels = self._tokenizer.encode(sample["output"])
# Collate
collated_prompt, collated_labels = self._collator(encoded_prompt, labels)
Contributor

Is there a reason you chose to use the collator within the dataset? Should this not be the responsibility of the DataLoader since the collate_fn is part of the DataLoader's signature? https://pytorch.org/docs/stable/data.html

Contributor

Similar to @winglian 's question, what does the collator do here as it either gets a single sample or a packed sample - that is an individual sample in both scenarios

Contributor Author

The primary reason is so all things data related can be instantiated / configured together from the config

dataset:
  _component_: torchtune.datasets.build_dataset
  template: my_template
  collator: my_collator
  ...

If we allowed users to configure the collate_fn in the DataLoader directly, it would just be separate in the config

dataset:
  _component_: torchtune.datasets.build_dataset
  template: my_template

collator:
  _component_: torchtune.datasets.collators.my_collator

which is a little awkward. Although coupling the collator with the dataset and not using collate_fn may bring some confusion.

Contributor

Would there be a dataloader component too, or is it just the dataset? If there is a dataloader component, collate_fn should be part of it since it is an argument of the OSS DataLoader.

Datasets we should support are:
- HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
- Local files, specified with a simple path string
- Remote files via HTTPS, specified with a URL
Contributor

I'm not sure how you feel about this library, https://pypi.org/project/smart-open/, but it provides a clean way to handle various formats, especially those commonly seen in more enterprise environments such as s3://, gs:// and hdfs://

two approaches to sample packing:
* Offline bin-packing: take a dataset and iterate through the sample lengths. Use
one of the algorithms to pack the dataset and either cache it or upload to HF hub.
Axolotl uses the first-fit-decreasing algorithm.
Contributor

There is also a "hybrid" approach that axolotl employs for large pretraining datasets whereby Axolotl uses IterableDatasets. In this case, it sets a large buffer window and packs all of the samples in that window. You still get pretty high packing efficiencies and the time spent packing is still relatively fast compared to finetuning. One optimization I've not tried is to ensure that the packing is handled async on CPU while training is happening. This would add negligible overhead to online packing.

Contributor @gokulavasan Mar 14, 2024

One optimization I've not tried is to ensure that the packing is handled async on CPU while training is happening.

@winglian Can this be achieved by just setting num_workers=1 (or more) in the oss dataloader? [Setting it to more than 1 requires configuring sharding correctly for iterable datasets]

Contributor

I haven't dug into how exactly it gets handled with multiple workers in that case.

Comment on lines +176 to +177
samples are uncorrelated. A more prudent approach would be to use an integer mask
that labels which sample is which to prevent cross-attending. This is something
Contributor

for simplicity, Axolotl uses integer masks now https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/utils/collators.py#L171-L173 as this allows axolotl to simplify the computation of cu_seqlens across various model architectures. However, you can also pass position_ids and achieve the same result by parsing position_ids: https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py#L99-L155
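(Roughly, parsing position_ids means recovering the sample boundaries from where positions restart at 0, e.g. the following illustrative sketch:)

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """position_ids like [0,1,2,0,1,0,1,2,3] -> cu_seqlens [0, 3, 5, 9]."""
    starts = torch.nonzero(position_ids == 0, as_tuple=True)[0]
    total = torch.tensor([position_ids.numel()], device=position_ids.device)
    return torch.cat([starts, total]).to(torch.int32)
```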

rfc_data.md Outdated
dataset:
_component_: torchtune.datasets.build_dataset
source: tatsu-lab/alpaca
tokenizer: # TBD how to handle tokenizers
Contributor

is tokenizer specific to the dataset? I don't feel there is a use case for a given model where we would need to specify different tokenizers across each dataset.

we make builders general enough that users can easily specify multiple datasets
in a config without using nested components?
- Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
- Should we perform sample packing offline or do an online ad-hoc approach?
Contributor

even for relatively large datasets, offline sample packing is pretty optimized when using numpy.

Contributor

For offline sample packing, are we thinking of a separate script? If that is the case, as discussed above there could be preprocessing associated with it, and so the configuration of that script needs to be figured out (maybe the dataset part of the config is for the offline script). But this would complicate things for folks who don't want packing.

So if we don't do an offline script but do it before training starts, would it be done in dataloader workers or before the training loop starts? Dataloader workers can provide a compute speed-up through parallelism (if we want to cut down on TTFB), but it would require a switch to an iterable dataset.

@winglian Here are you referring to full dataset offline preprocessing? Or is it the iterable-style approach? If it is the former, curious to know how much time it took for, say, the alpaca dataset to complete the offline sample packing.

Contributor

Sorry, I flip-flopped my words and meant to say online sample packing is optimized when using numpy. I ran a quick test on the tatsu-lab/alpaca dataset and it took 6 seconds to pack everything (obviously this will vary based on the various packing configurations like total length, etc.).

rfc_data.md Outdated

# Instead, specify a builder
dataset:
_component_: torchtune.datasets.build_dataset
Contributor

Just to clarify, the intention here is that build_dataset will know which components to combine based on the various parameters defined below?

# Get a row
sample = self._data[index]
# Apply any transforms
sample = self._transform(sample)
Contributor

Would the transform operate on both an individual sample as well as a packed sample? Same question, I guess, for template formatting.

Contributor Author

Hmm, that's a great point, I hadn't thought about how packed samples would be templated and transformed... I suppose the packing might need to be done after all the preprocessing?

Contributor

Yes I think packing should come after all preprocessing

- HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
- Local files, specified with a simple path string
- Remote files via HTTPS, specified with a URL
All should be readily supported by using HF’s load_dataset() API, including support
Contributor

Is the suggestion that the user uses HF's load_dataset to set up data loading from all of the above sources (including s3, gs, azure, etc.)?

- Remote files via HTTPS, specified with a URL
All should be readily supported by using HF’s load_dataset() API, including support
for JSON, JSONL, text, arrow, parquet, CSV, etc. We can provide a convenient utility
get_data or similar that handles loading the data from the right location given
Contributor

Curious to learn what the utility will do on top of what load_dataset does - is it calling load_dataset vs load_from_disk based on the path string?

Contributor

HF's load_dataset is really only for remote datasets, without any handling for local/on-disk data. In Axolotl we handle this by first checking whether the dataset path is a HuggingFace dataset (without actually pulling down the data) and using that as part of the flow control later.

method that does this as the transform here. Another example is llama recipes’ grammar
dataset which requires splitting and processing the strings directly before templating.

### Template
Contributor

I think the chat template provided by HF is based on that particular model/tokenizer (https://huggingface.co/docs/transformers/main/en/chat_templating) - the tokenizer.apply_chat_template method.

However, packing results in faster training since the model can see more samples
at a time, but just requires some additional processing of the dataset. There are
two approaches to sample packing:
* Offline bin-packing: take a dataset and iterate through the sample lengths. Use
Contributor

Since llama recipes is mentioned below, I am curious how we plan to handle scenarios where adding another sample to the pack would exceed the max seq length -- will we use that sample in the next packed sample, or will that sample be truncated and packed to max seq length? llama-recipes takes the latter approach.

rfc_data.md Outdated
def samsum_dataset(tokenizer: Tokenizer) -> TuneDataset:
return TuneDataset(
source="samsum",
template=SummarizeTemplate(),
Contributor

Is template taking an object? I also see a string used below.

in a config without using nested components?
- Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
- Should we perform sample packing offline or do an online ad-hoc approach?
- Are there other abstractions we need to consider, or abstractions we should simplify? (i.e., just using str instead of a PromptTemplate class)
Contributor

We certainly need a PromptTemplate class I think (even for the Alpaca example, as you called out, since it has two prompts). Can a plain string also be supported for simple use cases?

Contributor Author

yes, we should support both

in a config if nested configs / recursive instantiation are not enabled? How can
we make builders general enough that users can easily specify multiple datasets
in a config without using nested components?
- Should we use the collate_fn keyword in DataLoader for collators, or keep it coupled with TuneDataset?
Contributor

Looks like there could be another "collate fn" for the attention mask? For this question specifically, any pointers to what the dataloader config looks like today in torchtune?

Contributor Author

Currently not configured at all - we just use padded_collate as the default. We'll have to show in our example configs how this can be changed if needed.

)
```

### Open questions
Contributor

I) One feature that I recall @laurencer had wished for in this area is how to verify that the transform+prompt+tokenizer are doing the right thing. Any misconfiguration here would result in subtle bugs that impact model performance. Metrics around this could be part of the answer - is sample packing helping, what is the distribution of sample token sizes, is padding bloating up the input/label sizes, etc.

II) Also, error message/propagation - if one sample doesn't have the expected data/format (maybe the transform fails), can that particular sample be identified (sample number printed) so that the user can debug it?

Datasets we should support are:
- HuggingFace Hub, specified with the standard username/dataset_name. A common specification is the data_files parameter, which we could support by letting users continue the HF path: username/dataset_name/files/user/wants. Ex: allenai/c4/en/c4-train.0000*-of-01024.json.gz.
- Local files, specified with a simple path string
- Remote files via HTTPS, specified with a URL
Contributor

Are we considering this a P0? If so, why?

Contributor Author

this should be supported out of the box by load_dataset: https://huggingface.co/learn/nlp-course/en/chapter5/2#loading-a-remote-dataset
