Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[7/7] Multimodal datasets (The Cauldron, LLaVA-Instruct-150K) #1158

Merged
merged 60 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
75dae87
complete tokenizer refactor
RdoubleA Jun 12, 2024
0c20ba9
move tokenizers under data/
RdoubleA Jun 12, 2024
730a2c9
fix all tests
RdoubleA Jun 12, 2024
acf7e81
Merge branch 'main' into tokenizer
RdoubleA Jun 21, 2024
2ae157c
start to address comments
RdoubleA Jun 22, 2024
6a50cd5
load in special tokens, move tokenizer directory back, address comments
RdoubleA Jun 24, 2024
61534d0
fix encode whitespace
RdoubleA Jun 24, 2024
1d6e5e3
updates after manual comparisons
RdoubleA Jun 25, 2024
5712de4
default special tokens
RdoubleA Jun 26, 2024
d84bbda
fix docs
RdoubleA Jun 26, 2024
5a8b82b
fix doc strings
RdoubleA Jun 26, 2024
52643cb
Merge branch 'main' into tokenizer
RdoubleA Jun 26, 2024
a00c1dc
fix tests
RdoubleA Jun 26, 2024
29273ca
fix SP test
RdoubleA Jun 26, 2024
aa43095
add image support
RdoubleA Jun 26, 2024
8afaaf9
tool support
RdoubleA Jun 26, 2024
d3d4b66
update tests
RdoubleA Jun 26, 2024
d326dca
update tests
RdoubleA Jun 26, 2024
58e3e9d
use images as attachments instead
RdoubleA Jun 27, 2024
7fdccae
update all tests
RdoubleA Jun 27, 2024
820d9ac
use list of dicts for MM messages
RdoubleA Jun 27, 2024
7ba4216
fix chat formats
RdoubleA Jul 1, 2024
42f8c83
add multimodal dataset, test, and the cauldron
RdoubleA Jun 26, 2024
7cad2dc
multimodal dataset test
RdoubleA Jun 27, 2024
335e85f
fix rebase
RdoubleA Jul 1, 2024
adca77e
Merge branch 'main' into tokenizer
RdoubleA Jul 2, 2024
b204563
update api ref
RdoubleA Jul 2, 2024
e236916
Merge branch 'main' into tokenizer
RdoubleA Jul 2, 2024
93028cf
fix llama3 tokenizer test:
RdoubleA Jul 2, 2024
fb12cbb
add image support
RdoubleA Jun 26, 2024
b5bf410
tool support
RdoubleA Jun 26, 2024
00f266f
update tests
RdoubleA Jun 26, 2024
c815069
update tests
RdoubleA Jun 26, 2024
21b3ea8
use images as attachments instead
RdoubleA Jun 27, 2024
adbfb20
update all tests
RdoubleA Jun 27, 2024
1e40a9d
use list of dicts for MM messages
RdoubleA Jun 27, 2024
0d3665c
fix chat formats
RdoubleA Jul 1, 2024
95edf70
run linter
RdoubleA Jul 2, 2024
a3067aa
Merge branch 'main' into tokenizer_updates
RdoubleA Jul 2, 2024
d49febf
merge main
RdoubleA Jul 2, 2024
7da4189
fix chat formats
RdoubleA Jul 3, 2024
58babf0
Merge branch 'tokenizer_updates' into mm_dataset
RdoubleA Jul 3, 2024
7bcdaf8
fix merge
RdoubleA Jul 3, 2024
82e1dea
Merge branch 'main' into mm_dataset
RdoubleA Jul 9, 2024
ff81c5c
fix merge
RdoubleA Jul 9, 2024
258e98f
multimodal dataset, unit test, and two example dataset builders with …
RdoubleA Jul 10, 2024
1410d70
Merge branch 'main' into mm_dataset
RdoubleA Jul 10, 2024
5aea048
Merge branch 'main' into mm_dataset
RdoubleA Aug 21, 2024
2731a60
update with latest APIs
RdoubleA Aug 22, 2024
8530958
fix lint
RdoubleA Aug 22, 2024
ce1fe8a
Merge remote-tracking branch 'upstream/main' into mm_dataset
RdoubleA Aug 22, 2024
ddd5e86
Merge branch 'main' into mm_dataset
RdoubleA Aug 26, 2024
939a3f5
Merge branch 'main' into mm_dataset
RdoubleA Sep 3, 2024
278649a
remove image handling
RdoubleA Sep 3, 2024
5f2021f
Merge branch 'main' into mm_dataset
RdoubleA Sep 3, 2024
bb1a8b5
separate llava transform
RdoubleA Sep 3, 2024
e16b2ad
update tests
RdoubleA Sep 3, 2024
8079294
fix tests
RdoubleA Sep 3, 2024
b006a90
Merge branch 'main' into mm_dataset
RdoubleA Sep 4, 2024
703b986
update docstrings
RdoubleA Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added tests/assets/test_image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 23 additions & 1 deletion tests/torchtune/data/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import pytest
from torchtune.data import Message, truncate, validate_messages
from torchtune.data import Message, split_text_by_image_tag, truncate, validate_messages


def test_truncate():
Expand Down Expand Up @@ -88,3 +88,25 @@ def test_validate_messages():
match="Assistant message before expected user message at index 0 in messages",
):
validate_messages(messages)


def test_split_text_by_image_tag():
    """Splitting text around an image tag yields typed content chunks in order."""
    # Tag in the middle: text before and after the tag is preserved verbatim.
    assert split_text_by_image_tag("hello <image>world", "<image>") == [
        {"type": "text", "content": "hello "},
        {"type": "image"},
        {"type": "text", "content": "world"},
    ]

    # Leading tag plus a second occurrence; no empty text chunk is emitted
    # for the zero-length text before the first tag.
    assert split_text_by_image_tag("[image]hello [image]world", "[image]") == [
        {"type": "image"},
        {"type": "text", "content": "hello "},
        {"type": "image"},
        {"type": "text", "content": "world"},
    ]

    # Tag absent from the text: the whole string comes back as one text chunk.
    assert split_text_by_image_tag("hello world", "asdfghjkl;") == [
        {"type": "text", "content": "hello world"}
    ]
138 changes: 138 additions & 0 deletions tests/torchtune/datasets/test_llava_instruct_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import Counter
from pathlib import Path
from unittest.mock import patch

import pytest
from datasets import Dataset
from PIL import Image

from tests.test_utils import DummyTokenizer
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets import llava_instruct_dataset

ASSETS = Path(__file__).parent.parent.parent / "assets"


class TestLLaVAInstructDataset:
    """Tests for ``llava_instruct_dataset``: tokenization and label masking.

    Both tests feed the same mocked HF dataset row through the builder; the
    only difference is ``train_on_input``, so the sample and the expected
    token counts are shared via class-level helpers instead of duplicated.
    """

    # Token-id -> occurrence count expected after DummyTokenizer tokenizes the
    # mock conversation below. Identical for both tests since the input is.
    EXPECTED_COUNT = {
        3: 17,
        2: 15,
        4: 11,
        8: 9,
        5: 8,
        7: 8,
        6: 5,
        1: 5,
        9: 2,
        0: 1,
        -2: 1,
        12: 1,
        10: 1,
        -1: 1,
    }

    @pytest.fixture
    def tokenizer(self):
        return DummyTokenizer()

    @staticmethod
    def _mock_sample():
        """One LLaVA-Instruct-style row referencing the test image asset."""
        return {
            "image": "test_image.jpg",
            "conversations": [
                {
                    "from": "human",
                    "value": "<image>\nWhat can you infer about the man's outdoor activity?",
                },
                {
                    "from": "gpt",
                    "value": "From the image, we can infer that the man is engaging in a "
                    "recreational activity involving a frisbee in a park or grass field. "
                    "The frisbee is in the air, and the man appears to be either catching "
                    "or throwing it. This suggests that he might be playing a casual game "
                    "of catch with a friend or practicing his frisbee skills, enjoying the "
                    "outdoors and getting some physical activity at the same time.",
                },
            ],
        }

    @patch("torchtune.datasets._sft.load_dataset")
    def test_label_no_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is not masked.
        """
        # mock the call to HF datasets
        load_dataset.return_value = Dataset.from_list([self._mock_sample()])

        ds = llava_instruct_dataset(
            model_transform=tokenizer, train_on_input=True, coco_image_dir=str(ASSETS)
        )
        input, labels, images = ds[0]["tokens"], ds[0]["labels"], ds[0]["images"][0]

        assert Counter(input) == self.EXPECTED_COUNT
        # train_on_input=True: labels mirror the inputs, nothing is masked out.
        assert Counter(labels) == self.EXPECTED_COUNT
        assert isinstance(images, Image.Image)

    @patch("torchtune.datasets._sft.load_dataset")
    def test_label_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is masked.
        """
        # mock the call to HF datasets
        load_dataset.return_value = Dataset.from_list([self._mock_sample()])

        ds = llava_instruct_dataset(
            model_transform=tokenizer, train_on_input=False, coco_image_dir=str(ASSETS)
        )
        input, labels, images = ds[0]["tokens"], ds[0]["labels"], ds[0]["images"][0]

        assert Counter(input) == self.EXPECTED_COUNT
        # train_on_input=False: the 11 user-prompt token positions are ignored.
        assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 11
        assert isinstance(images, Image.Image)
157 changes: 157 additions & 0 deletions tests/torchtune/datasets/test_the_cauldron_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pdb
from unittest.mock import patch

import pytest
import torch
from datasets import Dataset

from tests.test_utils import DummyTokenizer
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets import the_cauldron_dataset
from torchvision.transforms import PILToTensor, ToPILImage


class TestTheCauldronDataset:
    """Tests for ``the_cauldron_dataset``: tokenization, label masking, and
    image pass-through.

    Both tests feed the same mocked HF dataset row through the builder; the
    only difference is ``train_on_input``, so the sample row and the expected
    token sequence are shared via class-level helpers instead of duplicated.
    """

    # Token ids DummyTokenizer produces for the mock Q/A sample below,
    # including BOS (0), image token (-2), and EOS (-1) markers.
    EXPECTED_TOKENS = [
        0, -2, 9, 4, 2, 11, 3, 10, 4, 3, 8, 2, 6, 2,
        6, 7, 2, 8, 2, 4, 6, 4, 3, 7, 7, 1, -1,
    ]

    @pytest.fixture
    def tokenizer(self):
        return DummyTokenizer()

    @staticmethod
    def _mock_sample(image_tensor):
        """One Cauldron-style row wrapping ``image_tensor`` as a PIL image."""
        return {
            "images": [ToPILImage()(image_tensor)],
            "texts": [
                {
                    "user": "Question: What do respiration and combustion give out"
                    "\nChoices:\nA. Oxygen\nB. Carbon dioxide\nC. Nitrogen\nD. Heat"
                    "\nAnswer with the letter.",
                    "assistant": "Answer: B",
                    "source": "AI2D",
                }
            ],
        }

    @patch("torchtune.datasets._sft.load_dataset")
    def test_label_no_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is not masked.
        """
        image_tensor = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
        # mock the call to HF datasets
        load_dataset.return_value = Dataset.from_list([self._mock_sample(image_tensor)])

        ds = the_cauldron_dataset(
            model_transform=tokenizer, subset="dummy", train_on_input=True
        )
        input, labels, images = (
            ds[0]["tokens"],
            ds[0]["labels"],
            ds[0]["images"][0],
        )

        assert input == self.EXPECTED_TOKENS
        # train_on_input=True: labels mirror the inputs, nothing is masked out.
        assert labels == input
        # NOTE: a stray `pdb.set_trace()` debug breakpoint was removed here —
        # it would hang any non-interactive test run at an interactive prompt.
        torch.testing.assert_close(PILToTensor()(images), image_tensor)

    @patch("torchtune.datasets._sft.load_dataset")
    def test_label_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is masked.
        """
        image_tensor = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
        # mock the call to HF datasets
        load_dataset.return_value = Dataset.from_list([self._mock_sample(image_tensor)])

        ds = the_cauldron_dataset(
            model_transform=tokenizer, subset="dummy", train_on_input=False
        )
        input, labels, images = (
            ds[0]["tokens"],
            ds[0]["labels"],
            ds[0]["images"][0],
        )

        assert input == self.EXPECTED_TOKENS
        # train_on_input=False: the 24 user-prompt token positions are ignored.
        assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 24
        torch.testing.assert_close(PILToTensor()(images), image_tensor)
4 changes: 3 additions & 1 deletion torchtune/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Message,
Role,
ShareGPTToMessages,
validate_messages,
)
from torchtune.data._prompt_templates import (
ChatMLTemplate,
Expand All @@ -29,7 +30,7 @@
QuestionAnswerTemplate,
SummarizeTemplate,
)
from torchtune.data._utils import truncate, validate_messages
from torchtune.data._utils import split_text_by_image_tag, truncate

__all__ = [
"ChatFormat",
Expand All @@ -46,6 +47,7 @@
"Message",
"validate_messages",
"Role",
"split_text_by_image_tag",
"PromptTemplateInterface",
"PromptTemplate",
"InputOutputToMessages",
Expand Down
Loading
Loading