Skip to content

Commit

Permalink
Add support for int4 weight-only QAT
Browse files Browse the repository at this point in the history
Summary: This commit adds support for int4 weight-only QAT,
which simulates the numerics of the existing
Int4WeightOnlyQuantizer. The main motivation for this is to
provide an end-to-end path for running QAT and lowering to
the efficient int4 tinygemm cuda kernel. To enable this,
we have to add new fake quantization primitives to match
the numerics of the tinygemm kernel, and this required
refactoring existing quant primitives to skip dtype casting.

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_linear

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim, HDCharles, supriyar
  • Loading branch information
andrewor14 committed Jun 16, 2024
1 parent 664f073 commit 2ac2250
Show file tree
Hide file tree
Showing 5 changed files with 529 additions and 51 deletions.
184 changes: 166 additions & 18 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,39 @@
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.quantization.utils import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
groupwise_affine_quantize_tensor,
groupwise_affine_quantize_tensor_from_qparams,
)
from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 32).to(torch.float),)
return (torch.randn(1, 256).to(torch.float),)

def forward(self, x):
return self.linear(x)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
self.sub = Sub()
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)
return (torch.randn(1, 512).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
Expand Down Expand Up @@ -111,23 +119,46 @@ def test_fake_quantize_per_token(self):

def _set_ptq_weight(
self,
ptq_linear: "Int8DynActInt4WeightLinear",
fp32_weight: torch.Tensor,
group_size: int,
ptq_linear: torch.nn.Module,
qat_linear: torch.nn.Module,
):
"""
Set the weight to the quantized version of the given fp32 weights,
for making linear outputs comparable with QAT.
"""
from torchao.quantization.GPTQ import (
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
from torchao.quantization.prototype.qat import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
)
ptq_linear.weight = q_weight
ptq_linear.scales = s
ptq_linear.zeros = zp
if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
fp32_weight = qat_linear.weight
group_size = qat_linear.groupsize
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
)
ptq_linear.weight = q_weight
ptq_linear.scales = s
ptq_linear.zeros = zp
elif isinstance(ptq_linear, WeightOnlyInt4Linear):
assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
(q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
qat_linear.weight, n_bit, qat_linear.groupsize,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to("cuda"), qat_linear.inner_k_tiles,
)
ptq_linear.weight = q_weight
ptq_linear.scales_and_zeros = scales_and_zeros
else:
raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
Expand All @@ -144,7 +175,7 @@ def test_qat_8da4w_linear(self):
)

# Force the weights to be the same
self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
self._set_ptq_weight(ptq_linear, qat_linear)

# Compare linear values
torch.manual_seed(self.SEED)
Expand Down Expand Up @@ -280,7 +311,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
loss_fn1 = torch.nn.CrossEntropyLoss()
loss_fn2 = torch.nn.CrossEntropyLoss()
example_inputs = nn_model.example_inputs()
target = torch.randn(1, 64).float()
target = torch.randn(1, 512).float()
output1 = nn_model(*example_inputs)
output2 = qat_model(*example_inputs)
torch.testing.assert_close(output1, output2, atol=0, rtol=0)
Expand Down Expand Up @@ -322,6 +353,123 @@ def test_qat_generic_fake_quantize(self):
torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)

def _assert_close_4w(self, val, ref):
# Note: for int4 weight-only quantization, we do not expect exact match
# because torch._weight_int4pack_mm and torch.mm do not match exactly.
# Here we use the same error bar as PyTorch core to determine closeness:
# https:/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
mean_err = ((val - ref) / ref).mean().abs()
print(mean_err)
self.assertTrue(mean_err < 0.05)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
inner_k_tiles = 8
scales_precision = torch.bfloat16
device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
x = torch.randn(100, 256, dtype=dtype, device=device)
weight = torch.randn(512, 256, dtype=dtype, device=device)

# PTQ
(q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
weight, n_bit, group_size, scales_precision,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to(device), inner_k_tiles,
)
ptq_out = torch.ops.aten._weight_int4pack_mm(
x, q_weight, group_size, scales_and_zeros
)

# QAT
scales, zero_points = get_groupwise_affine_qparams(
weight, n_bit, group_size, scales_precision,
)
w_q = groupwise_affine_quantize_tensor_from_qparams(
weight, scales, zero_points, n_bit, group_size, cast_dtypes=False,
)
w_dq = groupwise_affine_dequantize_tensor_from_qparams(
w_q, scales, zero_points, n_bit, group_size, cast_dtypes=False,
)
qat_out = torch.nn.functional.linear(x, w_dq)

self._assert_close_4w(qat_out, ptq_out)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
qat_linear = Int4WeightOnlyQATLinear(
256, 688, bias=False, groupsize=group_size, device=device,
)
ptq_linear = WeightOnlyInt4Linear(
256, 688, bias=False, groupsize=group_size, device=device,
)

# Force the weights to be the same
self._set_ptq_weight(ptq_linear, qat_linear)

# Compare linear values
torch.manual_seed(self.SEED)
x = torch.randn(100, 256, dtype=dtype, device=device)
x2 = copy.deepcopy(x)
qat_out = qat_linear(x)
ptq_out = ptq_linear(x2)
self._assert_close_4w(qat_out, ptq_out)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
inner_k_tiles = 8
device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
m2 = copy.deepcopy(m)
qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
ptq_quantizer = Int4WeightOnlyQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = [i.to(device).to(dtype) for i in m.example_inputs()]
x2 = copy.deepcopy(x)
qat_out = qat_model(*x)
ptq_out = ptq_model(*x2)
self._assert_close_4w(qat_out, ptq_out)

# Convert QAT model and compare model values
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 2ac2250

Please sign in to comment.