Skip to content

Commit

Permalink
improve logic, split samples, add test on real data
Browse files Browse the repository at this point in the history
  • Loading branch information
RdoubleA committed Apr 27, 2024
1 parent 8df0898 commit d760d1d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 15 deletions.
51 changes: 47 additions & 4 deletions tests/torchtune/datasets/test_packed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import pytest
from tests.test_utils import DummyTokenizer
from torch.utils.data import Dataset

from torchtune.datasets import PackedDataset
Expand All @@ -20,24 +21,66 @@ def __getitem__(self, index):
return [index] * self.sample_size, [index] * self.sample_size


class DummyRealDataset(Dataset):
def __init__(self):
self.samples_list = [
"This is a packing test",
"A fantastic test. It should pack two samples.",
"This one will not be fully packed.",
]
self.tokenizer = DummyTokenizer()

def __getitem__(self, index):
tokens = self.tokenizer.encode(self.samples_list[index])
return tokens, tokens


class TestPackedDataset:
@pytest.mark.parametrize("max_seq_len", [10, 25])
@pytest.mark.parametrize("sample_size", [2, 5])
@pytest.mark.parametrize("max_rows", [5, 10])
def test_packed_dataset(self, max_seq_len, sample_size, max_rows):
@pytest.mark.parametrize("split_samples", [True, False])
def test_packed_dataset(self, max_seq_len, sample_size, max_rows, split_samples):
dataset = DummyDataset(sample_size)
packed = PackedDataset(
dataset,
max_seq_len=max_seq_len,
max_rows=max_rows,
split_samples=split_samples,
)
# Check we get right number of packs
assert len(packed) == max_rows
# Check input ids and labels are same length
assert len(packed[0][0]) == len(packed[0][1])
# Check that samples are packed correctly - very last individual sample
# should have index value of the number of times dataset was iterated over
last_index, remainder = divmod(max_rows * max_seq_len, sample_size)
if remainder > 0:
last_index += 1
if split_samples:
# If we split samples, we'll know how many samples by taking the
# full length and dividing by sample size
last_index, remainder = divmod(max_rows * max_seq_len, sample_size)
# Account for remaining sample that didn't fit in window
last_index = last_index + 1 if remainder > 0 else last_index
else:
# If we don't split samples, we know how many samples by taking
# how much fits in a single window and multiplying by max rows.
# We don't account for remainder sample because we'll hit max rows.
last_index = (max_seq_len // sample_size) * max_rows

assert packed[-1][0][-1] == last_index - 1

def test_packed_dataset_real_data(self):
expected_tokenized_prompts = [
[0, 4, 2, 1, 7, 4, -1, 0, 1, 9],
[5, 2, 6, 4, 3, 8, -1, 0, 4, 3],
[4, 3, 2, 5, 7, -1],
]
packed = PackedDataset(
DummyRealDataset(),
max_seq_len=10,
split_samples=True,
)

for i in range(len(packed)):
prompt, label = packed[i]
assert prompt == expected_tokenized_prompts[i]
assert label == expected_tokenized_prompts[i]
57 changes: 46 additions & 11 deletions torchtune/datasets/_packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,24 @@ class on packed samples as part of the dataloader.
inputs and labels.
max_seq_len (int): Maximum number of tokens to pack
max_rows (Optional[int]): maximum number of samples to pack. Default is None, which will pack as many samples as possible.
split_samples (bool): if the last sample in a pack does not fit in ``max_seq_len``,
split the sample into the next pack, or move it to the beginning of the next pack.
For pre-training, typically this is set to True for general text completion. For
fine-tuning, typically this is set to False to avoid truncating sentences in instruct
tuning. Default is False.
"""

def __init__(
self,
ds: Dataset,
max_seq_len: int,
max_rows: Optional[int] = None,
split_samples: bool = False,
) -> None:
self.ds = ds
self.max_seq_len = max_seq_len
self.max_rows = max_rows
self.split_samples = split_samples
# where final samples will be held
self.samples: List[Dict[str, List[int]]] = []
self._pack()
Expand All @@ -50,26 +57,54 @@ def _pack(self) -> None:
until max_rows or end of dataset.
"""
# buffer to hold samples until they are long enough to be added to self.samples
buffer = {
current_pack = {
"input_ids": [],
"labels": [],
}
# Keep track of what index the previous sample ends in case we need
# to end a pack early
previous_sample_boundary = 0

for input_ids, labels in tqdm(
self.ds, desc="Packing dataset", dynamic_ncols=True
):
buffer["input_ids"].extend(input_ids)
buffer["labels"].extend(labels)
# If the dataset outputs samples that are larger than the specified
# max_seq_len and we're unable to split it, user needs to modify
# one of the two parameters
if len(input_ids) > self.max_seq_len and not self.split_samples:
raise ValueError(
f"Dataset sample is too long ({len(input_ids)} > {self.max_seq_len}). "
"Please set `split_samples=True` or increase `max_seq_len`."
)

current_pack["input_ids"].extend(input_ids)
current_pack["labels"].extend(labels)

# If buffer has reached max_seq_len, append packed sample
while len(buffer["input_ids"]) > self.max_seq_len:
self.samples.append(
{k: v[: self.max_seq_len] for k, v in buffer.items()}
if len(current_pack["input_ids"]) > self.max_seq_len:
current_pack = self._add_pack(
current_pack=current_pack,
boundary=self.max_seq_len
if self.split_samples
else previous_sample_boundary,
)
buffer = {k: v[self.max_seq_len :] for k, v in buffer.items()}
assert len(buffer["input_ids"]) == len(buffer["labels"])
if self.max_rows is not None and len(self.samples) >= self.max_rows:
return

previous_sample_boundary = len(current_pack["input_ids"])
if self.max_rows is not None and len(self.samples) >= self.max_rows:
break

if len(current_pack["input_ids"]) > 0 and (
self.max_rows is None or len(self.samples) < self.max_rows
):
self.samples.append(dict(current_pack))

def _add_pack(
self, current_pack: Dict[str, List[int]], boundary: int
) -> Dict[str, List[int]]:
"""
Add the current pack to self.samples and return what's remaining of the pack.
"""
self.samples.append({k: v[:boundary] for k, v in current_pack.items()})
return {k: v[boundary:] for k, v in current_pack.items()}

def __len__(self):
return len(self.samples)
Expand Down

0 comments on commit d760d1d

Please sign in to comment.