-
Notifications
You must be signed in to change notification settings - Fork 404
/
test_utils.py
221 lines (188 loc) · 6.49 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import os
import re
import sys
import unittest
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, TextIO, Tuple, Union

import pytest
import torch
from torch import nn

from torchtune.modules.tokenizers import SentencePieceTokenizer
# unittest decorator: skip the decorated test when no CUDA device is present.
skip_if_cuda_not_available = unittest.skipIf(
    not torch.cuda.is_available(), "CUDA is not available"
)
# Small checkpoint artifacts used across tests, keyed by
# "<model>_<checkpoint format>" (tune/meta/hf = torchtune, Meta, and HF
# checkpoint formats). NOTE(review): presumably these files are downloaded
# into /tmp/test-artifacts by CI before the test run — confirm.
CKPT_MODEL_PATHS = {
    "llama2_tune": "/tmp/test-artifacts/small-ckpt-tune-03082024.pt",
    "llama2_meta": "/tmp/test-artifacts/small-ckpt-meta-03082024.pt",
    "llama2_hf": "/tmp/test-artifacts/small-ckpt-hf-03082024.pt",
    "llama3_tune": "/tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt",
    "llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
}
# Pre-downloaded tokenizer model files, keyed by model family.
TOKENIZER_PATHS = {
    "llama2": "/tmp/test-artifacts/tokenizer.model",
    "llama3": "/tmp/test-artifacts/tokenizer_llama3.model",
}
def torch_version_ge(version: str) -> bool:
    """
    Check if the installed torch version is greater than or equal to ``version``.

    Args:
        version (str): version string to compare against, e.g. "2.2.0".

    Returns:
        bool: True if ``version`` appears in ``torch.__version__`` (e.g.
            "2.2.0" in "2.2.0a0+git...") or the numeric release components of
            ``torch.__version__`` are >= those of ``version``.
    """
    if version in torch.__version__:
        return True

    # Compare numeric release components rather than raw strings: a plain
    # lexicographic string comparison would wrongly report "2.9" >= "2.10".
    def _components(v: str) -> tuple:
        # Drop any local version segment ("+git...") and keep integer fields.
        return tuple(int(n) for n in re.findall(r"\d+", v.split("+")[0]))

    return _components(torch.__version__) >= _components(version)
# Inherit from SentencePieceTokenizer class to reuse its tokenize_messages method
class DummyTokenizer(SentencePieceTokenizer):
    """Toy tokenizer for tests: each whitespace-separated word is encoded as
    its character count, framed by bos_id (0) and eos_id (-1)."""

    def __init__(self):
        self.encodes_whitespace = False

    @property
    def bos_id(self):
        return 0

    @property
    def eos_id(self):
        return -1

    def encode(self, text, add_bos=True, add_eos=True, **kwargs):
        # Token id for each word is simply that word's length.
        token_ids = [len(word) for word in text.split()]
        if add_bos:
            token_ids.insert(0, self.bos_id)
        if add_eos:
            token_ids.append(self.eos_id)
        return token_ids
def get_assets_path():
    """Return the ``assets`` directory that lives alongside this file."""
    assets_dir = Path(__file__).parent / "assets"
    return assets_dir
def fixed_init_tensor(
    shape: torch.Size,
    min_val: Union[float, int] = 0.0,
    max_val: Union[float, int] = 1.0,
    nonlinear: bool = False,
    dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Utility for generating deterministic tensors of a given shape. In general stuff
    like torch.ones, torch.eye, etc can result in trivial outputs. This utility
    generates a range tensor [min_val, max_val) of a specified dtype, applies
    a sine function if nonlinear=True, then reshapes to the appropriate shape.

    Args:
        shape (torch.Size): shape of the returned tensor.
        min_val (Union[float, int]): inclusive lower bound of the range. Default: 0.0
        max_val (Union[float, int]): exclusive upper bound of the range. Default: 1.0
        nonlinear (bool): if True, apply sin() elementwise. Default: False
        dtype (torch.dtype): dtype of the returned tensor. Default: torch.float

    Returns:
        torch.Tensor: deterministic tensor of shape ``shape``.
    """
    n_elements = math.prod(shape)
    step_size = (max_val - min_val) / n_elements
    # Build exactly n_elements evenly spaced points. Calling
    # torch.arange(min_val, max_val, step_size) directly can produce one
    # element too many or too few due to float rounding of the step, which
    # would make the reshape below fail.
    x = (min_val + step_size * torch.arange(n_elements)).to(dtype)
    x = x.reshape(shape)
    if nonlinear:
        return torch.sin(x)
    return x
@torch.no_grad
def fixed_init_model(
    model: nn.Module,
    min_val: Union[float, int] = 0.0,
    max_val: Union[float, int] = 1.0,
    nonlinear: bool = False,
    dtype: Optional[torch.dtype] = None,
):
    """
    Deterministically initialize every parameter of ``model`` in place via
    ``fixed_init_tensor``; see that docstring for the meaning of each argument.
    When ``dtype`` is None, each parameter keeps its own dtype.
    """
    for param in model.parameters():
        target_dtype = dtype if dtype is not None else param.dtype
        init_values = fixed_init_tensor(
            param.shape,
            min_val=min_val,
            max_val=max_val,
            nonlinear=nonlinear,
            dtype=target_dtype,
        )
        param.copy_(init_values)
def assert_expected(
    actual: Any,
    expected: Any,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    check_device: bool = True,
):
    """
    Thin wrapper around ``torch.testing.assert_close`` whose failure message
    includes both the actual and the expected value.
    """
    failure_msg = f"actual: {actual}, expected: {expected}"
    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        check_device=check_device,
        msg=failure_msg,
    )
@contextmanager
def single_box_init(init_pg: bool = True):
    """
    Context manager that sets the env vars for a single-process distributed
    run (MASTER_ADDR/MASTER_PORT/LOCAL_RANK/RANK/WORLD_SIZE), optionally
    initializes a gloo process group, and restores the original environment
    (and destroys the process group) on exit.

    Args:
        init_pg (bool): if True, also init/destroy a gloo process group.
            Default: True
    """
    env_vars = ["MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK", "RANK", "WORLD_SIZE"]
    # Snapshot current values (None if unset) so they can be restored on exit.
    initial_os = {k: os.environ.get(k, None) for k in env_vars}
    os.environ["MASTER_ADDR"] = "localhost"
    # TODO: Don't hardcode ports as this could cause flakiness if tests execute
    # in parallel.
    os.environ["MASTER_PORT"] = str(12345)
    os.environ["LOCAL_RANK"] = str(0)
    os.environ["RANK"] = str(0)
    os.environ["WORLD_SIZE"] = str(1)
    if init_pg:
        torch.distributed.init_process_group(
            backend="gloo",
            world_size=int(os.environ["WORLD_SIZE"]),
            rank=int(os.environ["RANK"]),
        )
    try:
        yield
    finally:
        if init_pg:
            torch.distributed.destroy_process_group()
        # Restore the prior environment, deleting any var we introduced.
        for k in env_vars:
            if initial_os.get(k) is None:
                del os.environ[k]
            else:
                os.environ[k] = initial_os[k]
@contextmanager
def set_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """Temporarily set the torch default dtype, restoring the prior one on exit."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)
@contextmanager
def captured_output() -> Generator[Tuple[TextIO, TextIO], None, None]:
    """Temporarily redirect stdout/stderr into StringIO buffers and yield them."""
    saved = (sys.stdout, sys.stderr)
    buffers = (StringIO(), StringIO())
    sys.stdout, sys.stderr = buffers
    try:
        yield buffers
    finally:
        sys.stdout, sys.stderr = saved
def gpu_test(gpu_count: int = 1):
    """
    Annotation for GPU tests, skipping the test if the
    required amount of GPU is not available
    """
    available = torch.cuda.device_count()
    reason = f"Not enough GPUs to run the test: requires {gpu_count}"
    return pytest.mark.skipif(available < gpu_count, reason=reason)
def get_loss_values_from_metric_logger(log_file_path: str) -> List[float]:
    """
    Given the path to a metric logger .txt file, parse it and return the loss
    value from each logged iteration.

    Args:
        log_file_path (str): path to the metric logger output file.

    Returns:
        List[float]: one loss per "loss:<float>" occurrence, in file order.
    """
    with open(log_file_path, "r") as f:
        logs = f.read()
    # Match e.g. "loss:2.5"; the captured group is the numeric value.
    losses = [float(x) for x in re.findall(r"loss:(\d+\.\d+)", logs)]
    return losses
def gen_log_file_name(tmpdir, suffix: Optional[str] = None) -> str:
    """
    Take the tmpdir and just append a non-path version of it as the
    filename, optionally adding specified suffix. This is used to
    write metric logs to a deterministic file per test run.
    E.g. /tmp/my/dir -> /tmp/my/dir/tmpmydir.txt

    Args:
        tmpdir: temporary directory (anything with a sensible ``str()``).
        suffix (Optional[str]): optional string appended before ".txt".

    Returns:
        str: full path of the log file inside ``tmpdir``.
    """
    # Flatten the directory path into a filename and place it *inside* tmpdir,
    # as the example above shows. (Bare string concatenation of tmpdir and the
    # flattened name would drop the path separator and put the file in
    # tmpdir's parent instead.)
    filename = str(tmpdir).replace("/", "")
    if suffix:
        filename += suffix
    filename += ".txt"
    return os.path.join(str(tmpdir), filename)