-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_autocompile.py
129 lines (96 loc) · 4.18 KB
/
test_autocompile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from torch import nn
from autocompile import ModuleCompiler
def test_static_shape():
    """A module always called with one input shape should get a static shape.

    The model is run twice with ``(1, 3)`` inputs. The compiler must select
    exactly one module, under the empty key ``""`` (the object handed to
    ``ModuleCompiler``, whose attribute path is empty). That module must
    report a single input whose ``shape`` equals ``torch.Size([1, 3])``.
    """

    class StaticModel(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    model = StaticModel().eval()
    compiler = ModuleCompiler(model)
    # Run the model multiple times with the same input shape
    inputs = [(torch.randn(1, 3),), (torch.randn(1, 3),)]
    compiler.run_model_many(model, inputs)
    modules_to_compile = compiler.determine_modules_to_compile()
    # Pin the exact key set, not just the count. With a count-only check, a
    # wrongly named module would surface below as a KeyError instead of an
    # assertion failure that shows the actual keys.
    assert set(modules_to_compile.keys()) == {""}
    trt_inputs = modules_to_compile[""]
    assert len(trt_inputs) == 1
    input_shape = trt_inputs[0].shape
    assert input_shape == torch.Size([1, 3])
    print("Static shape test passed.")
def test_dynamic_shape():
    """A module called with varying batch sizes should get a shape range.

    The model is run with inputs of shape ``(1, 3)``, ``(2, 3)``, ``(3, 3)``
    and ``(4, 3)``. The compiler must select exactly one module, under the
    empty key ``""``. That module must report a single input whose ``shape``
    is the dict ``{"min_shape": (1, 3), "opt_shape": (4, 3),
    "max_shape": (4, 3)}``.
    """

    class DynamicModel(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1

    model = DynamicModel().eval()
    compiler = ModuleCompiler(model)
    # Run the model multiple times with different input shapes
    # (batch sizes 1..4, feature dimension fixed at 3).
    inputs = [(torch.randn(i, 3),) for i in range(1, 5)]
    compiler.run_model_many(model, inputs)
    modules_to_compile = compiler.determine_modules_to_compile()
    # Pin the exact key set, not just the count, so a wrongly named module
    # fails here with the actual keys rather than as a KeyError below.
    assert set(modules_to_compile.keys()) == {""}
    trt_inputs = modules_to_compile[""]
    assert len(trt_inputs) == 1
    trt_input = trt_inputs[0]
    expected = {"min_shape": (1, 3), "opt_shape": (4, 3), "max_shape": (4, 3)}
    assert trt_input.shape == expected
    print("Dynamic shape test passed.")
def test_complex_pipeline():
    """Only the tensor-in/tensor-out ``nn.Module`` leaves of a pipeline are selected.

    ``Predictor.predict`` chains two stages:

    - an image stage: a plain-callable tokenizer feeding ``ImageModel``;
    - an upscaler stage: ``BadProcessor``, which takes an extra ``int``
      argument, feeding ``UpscalerModel``.

    After running ``predict`` with n=2 and n=3, the compiler must pick
    exactly ``image_pipe.model`` and ``upscaler.model``. Each must have one
    input with the min/opt/max shapes asserted below.
    """

    class Tokenizer:
        # Not an nn.Module: turns each text into a one-element float row (its length).
        def __call__(self, texts: list[str]) -> torch.Tensor:
            lengths = [[len(item)] for item in texts]
            return torch.tensor(lengths, dtype=torch.float32)

    class ImageModel(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    class BadProcessor(nn.Module):
        # Takes a non-tensor `count` argument and stacks `count` copies of x.
        def forward(self, x: torch.Tensor, count: int) -> torch.Tensor:
            return torch.stack([x] * count)

    class UpscalerModel(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1

    class ImagePipeline:
        def __init__(self):
            self.tokenizer = Tokenizer()
            self.model = ImageModel()

        def generate(self, prompt: str, n: int) -> torch.Tensor:
            return self.model(self.tokenizer([prompt] * n))

    class UpscalerPipeline:
        def __init__(self):
            self.image_processor = BadProcessor()
            self.model = UpscalerModel()

        def upscale(self, images: torch.Tensor) -> torch.Tensor:
            return self.model(self.image_processor(images, 3))

    class Predictor:
        def __init__(self):
            self.image_pipe = ImagePipeline()
            self.upscaler = UpscalerPipeline()

        def predict(self, prompt: str, n: int) -> torch.Tensor:
            return self.upscaler.upscale(self.image_pipe.generate(prompt, n))

    predictor = Predictor()
    compiler = ModuleCompiler(predictor)
    # Two calls with batch sizes 2 and 3 so each selected module sees a range.
    compiler.run_model_many(predictor.predict, [("hello", 2), ("world", 3)])
    modules_to_compile = compiler.determine_modules_to_compile()

    # Expected selection, with the single input's shape range per module.
    # Upscaler inputs are BadProcessor outputs: (3, batch, 1).
    expected_shapes = {
        "image_pipe.model": {
            "min_shape": (2, 1),
            "opt_shape": (3, 1),
            "max_shape": (3, 1),
        },
        "upscaler.model": {
            "min_shape": (3, 2, 1),
            "opt_shape": (3, 3, 1),
            "max_shape": (3, 3, 1),
        },
    }
    assert set(modules_to_compile.keys()) == set(expected_shapes)
    for module_name, shape_range in expected_shapes.items():
        selected_inputs = modules_to_compile[module_name]
        assert len(selected_inputs) == 1
        assert selected_inputs[0].shape == shape_range
    print("Complex pipeline test passed.")
if __name__ == "__main__":
    # Allow running the suite directly (without pytest), in definition order.
    for run_test in (test_static_shape, test_dynamic_shape, test_complex_pipeline):
        run_test()