diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 433fdcb28..6e2859e10 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -18,18 +18,26 @@
     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    get_group_qparams_symmetric,
+    get_groupwise_affine_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
+    groupwise_affine_quantize_tensor,
+    groupwise_affine_quantize_tensor_from_qparams,
+)
 from torchao.utils import TORCH_VERSION_AFTER_2_4


 # TODO: put this in a common test utils file
+_CUDA_IS_AVAILABLE = torch.cuda.is_available()
+
 class Sub(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
+        self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)

     def example_inputs(self):
-        return (torch.randn(1, 32).to(torch.float),)
+        return (torch.randn(1, 256).to(torch.float),)

     def forward(self, x):
         return self.linear(x)
@@ -37,12 +45,12 @@ def forward(self, x):
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
         self.sub = Sub()
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, 512).to(torch.float),)

     def forward(self, x):
         x = self.linear1(x)
@@ -111,23 +119,46 @@ def test_fake_quantize_per_token(self):

     def _set_ptq_weight(
         self,
-        ptq_linear: "Int8DynActInt4WeightLinear",
-        fp32_weight: torch.Tensor,
-        group_size: int,
+        ptq_linear: torch.nn.Module,
+        qat_linear: torch.nn.Module,
     ):
         """
         Set the weight to the quantized version of the given fp32 weights,
         for making linear outputs comparable with QAT.
""" + from torchao.quantization.GPTQ import ( + Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, + ) + from torchao.quantization.prototype.qat import ( + Int8DynActInt4WeightQATLinear, + Int4WeightOnlyQATLinear, + ) n_bit = 4 (qmin, qmax) = self._get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) - q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, - ) - ptq_linear.weight = q_weight - ptq_linear.scales = s - ptq_linear.zeros = zp + if isinstance(ptq_linear, Int8DynActInt4WeightLinear): + assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear) + fp32_weight = qat_linear.weight + group_size = qat_linear.groupsize + (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) + q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( + fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + ) + ptq_linear.weight = q_weight + ptq_linear.scales = s + ptq_linear.zeros = zp + elif isinstance(ptq_linear, WeightOnlyInt4Linear): + assert isinstance(qat_linear, Int4WeightOnlyQATLinear) + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + qat_linear.weight, n_bit, qat_linear.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to("cuda"), qat_linear.inner_k_tiles, + ) + ptq_linear.weight = q_weight + ptq_linear.scales_and_zeros = scales_and_zeros + else: + raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): @@ -144,7 +175,7 @@ def test_qat_8da4w_linear(self): ) # Force the weights to be the same - self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size) + self._set_ptq_weight(ptq_linear, qat_linear) # Compare linear values torch.manual_seed(self.SEED) @@ -280,7 +311,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): loss_fn1 = torch.nn.CrossEntropyLoss() loss_fn2 = torch.nn.CrossEntropyLoss() example_inputs = nn_model.example_inputs() - target = torch.randn(1, 64).float() + target = torch.randn(1, 512).float() output1 = nn_model(*example_inputs) output2 = qat_model(*example_inputs) torch.testing.assert_close(output1, output2, atol=0, rtol=0) @@ -322,6 +353,123 @@ def test_qat_generic_fake_quantize(self): torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0) + def _assert_close_4w(self, val, ref): + # Note: for int4 weight-only quantization, we do not expect exact match + # because torch._weight_int4pack_mm and torch.mm do not match exactly. 
+        # Here we use the same error bar as PyTorch core to determine closeness:
+        # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
+        mean_err = ((val - ref) / ref).mean().abs()
+        print(mean_err)
+        self.assertTrue(mean_err < 0.05)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_primitives(self):
+        n_bit = 4
+        group_size = 32
+        inner_k_tiles = 8
+        scales_precision = torch.bfloat16
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        weight = torch.randn(512, 256, dtype=dtype, device=device)
+
+        # PTQ
+        (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+            weight, n_bit, group_size, scales_precision,
+        )
+        q_weight = torch.ops.aten._convert_weight_to_int4pack(
+            q_weight.to(device), inner_k_tiles,
+        )
+        ptq_out = torch.ops.aten._weight_int4pack_mm(
+            x, q_weight, group_size, scales_and_zeros
+        )
+
+        # QAT
+        scales, zero_points = get_groupwise_affine_qparams(
+            weight, n_bit, group_size, scales_precision,
+        )
+        w_q = groupwise_affine_quantize_tensor_from_qparams(
+            weight, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        w_dq = groupwise_affine_dequantize_tensor_from_qparams(
+            w_q, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        qat_out = torch.nn.functional.linear(x, w_dq)
+
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_linear(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
+        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+
+        group_size = 128
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        qat_linear = Int4WeightOnlyQATLinear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+        ptq_linear = WeightOnlyInt4Linear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+
+        # Force the weights to be the same
+        self._set_ptq_weight(ptq_linear, qat_linear)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        x2 = copy.deepcopy(x)
+        qat_out = qat_linear(x)
+        ptq_out = ptq_linear(x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_quantizer(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+
+        group_size = 32
+        inner_k_tiles = 8
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        m = M().to(device).to(dtype)
+        m2 = copy.deepcopy(m)
+        qat_quantizer = Int4WeightOnlyQATQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        ptq_quantizer = Int4WeightOnlyQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = [i.to(device).to(dtype) for i in m.example_inputs()]
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+        # Convert QAT model and compare model values
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 996d812ba..33f573e5f 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -9,7 +9,7 @@
 # LICENSE file in the root directory of this source tree.
 import logging

-from typing import Optional, List, Type
+from typing import Optional, Callable, List, Type

 import torch
@@ -522,14 +522,22 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
         return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
     return k_divisible_by_groupsize

-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+def linear_forward_int4(
+    x: torch.Tensor,
+    weight_int4pack: torch.Tensor,
+    scales_and_zeros: torch.Tensor,
+    out_features: int,
+    groupsize: int,
+    precision: torch.dtype = torch.bfloat16,
+    scales_precision: torch.dtype = torch.bfloat16,
+):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
     c = torch.ops.aten._weight_int4pack_mm(
-        x.to(torch.bfloat16),
+        x.to(precision),
         weight_int4pack,
         groupsize,
-        scales_and_zeros.to(torch.bfloat16)
+        scales_and_zeros.to(scales_precision)
     ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
@@ -544,6 +552,7 @@ class WeightOnlyInt4Linear(torch.nn.Module):
     def __init__(
         self, in_features: int, out_features: int,
         bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+        precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
         self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
@@ -555,40 +564,92 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         assert not bias, "require bias=False"
+        self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
+        self.precision = precision
+        self.scales_precision = scales_precision
         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device)
         )
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
+            torch.empty((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device)
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
-            import torch.nn.functional as F
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
-            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+            self.weight,
+            self.scales_and_zeros,
+            self.out_features,
+            self.groupsize,
+            self.precision,
+            self.scales_precision,
         )

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None):
+def _replace_linear_int4(
+    module: torch.nn.Module,
+    groupsize: int,
+    inner_k_tiles: Optional[int],
+    padding_allowed: bool,
+    skip_layer_func: Optional[Callable] = None,
+    precision: torch.dtype = torch.bfloat16,
+    scales_precision: torch.dtype = torch.bfloat16,
+    linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear,
+    copy_weights: bool = False,
+):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles,
-                ))
+                new_linear = linear_class(
+                    child.in_features,
+                    child.out_features,
+                    bias=False,
+                    device=child.weight.device,
+                    groupsize=groupsize,
+                    inner_k_tiles=inner_k_tiles,
+                    precision=precision,
+                    scales_precision=scales_precision,
+                )
+                # TODO: merge with 8da4w?
+                # In distributed training, the model may be instantiated
+                # on the meta device, in which case there is no need to
+                # copy the weights, and doing so will result in an error
+                if copy_weights and child.weight.device != torch.device("meta"):
+                    new_linear.weight = child.weight
+                setattr(module, name, new_linear)
             else:
-                replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func)
+                _replace_linear_int4(
+                    child,
+                    groupsize,
+                    inner_k_tiles,
+                    padding_allowed,
+                    skip_layer_func,
+                    precision,
+                    scales_precision,
+                    linear_class,
+                    copy_weights,
+                )
+
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None):
+    _replace_linear_int4(
+        module,
+        groupsize,
+        inner_k_tiles,
+        padding_allowed,
+        skip_layer_func,
+        linear_class=WeightOnlyInt4Linear,
+    )
+

 class Int4WeightOnlyQuantizer(Quantizer):
     def __init__(
@@ -646,6 +707,7 @@ def _create_quantized_state_dict(
                     4,  # n_bit
                     self.groupsize,
                 )
+                # TODO: just get the device from mod.weight.device?
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device)
@@ -669,6 +731,7 @@ def quantize(
         model.load_state_dict(state_dict, strict=False)
         return model

+
 class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
     def __init__(
         self,
@@ -1001,6 +1064,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
             self.groupsize,
             self.padding_allowed,
             self.precision,
+            # TODO: this should be self.scales_precision?
             self.precision,
         )
         return model
@@ -1086,6 +1150,7 @@ def _convert_for_runtime(self, model):
             self.groupsize,
             self.padding_allowed,
             self.precision,
+            # TODO: this should be self.scales_precision?
             self.precision,
         )
         return model
diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
index 71b585b15..2d86f79c6 100644
--- a/torchao/quantization/prototype/qat.py
+++ b/torchao/quantization/prototype/qat.py
@@ -4,20 +4,34 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Tuple
+from typing import Any, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.library import impl

-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    get_group_qparams_symmetric,
+    groupwise_affine_dequantize_tensor,
+)
 from torchao.quantization.unified import TwoStepQuantizer
 from torchao.quantization.GPTQ import (
+    _check_linear_int4_k,
+    _replace_linear_int4,
     _replace_linear_8da4w,
+    get_groupwise_affine_qparams,
+    groupwise_affine_quantize_tensor,
+    groupwise_affine_quantize_tensor_from_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
     Int8DynActInt4WeightLinear,
+    WeightOnlyInt4Linear,
 )
+# =================
+# |   8da4w QAT   |
+# =================


 class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
     """
@@ -171,7 +185,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
         else:
             w_fq = self.weight
-        return torch.nn.functional.linear(x_fq, w_fq)
+        return F.linear(x_fq, w_fq)

     # TODO: move this to common util
     def _get_qmin_qmax(self, n_bit: int):
@@ -193,10 +207,194 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
     if isinstance(mod, Int8DynActInt4WeightQATLinear):
         mod.disable_fake_quant()
+
+# ==================
+# |   int4wo QAT   |
+# ==================
+class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
+    """
+    Quantizer for performing QAT on a model, where linear layers have
+    int4 fake quantized grouped per channel weights.
+    """
+
+    def __init__(
+        self,
+        groupsize: int = 256,
+        inner_k_tiles: Optional[int] = 8,
+        precision: torch.dtype = torch.bfloat16,
+        scales_precision: torch.dtype = torch.bfloat16,
+    ) -> None:
+        super().__init__()
+        assert inner_k_tiles in [2, 4, 8]
+        assert groupsize in [32, 64, 128, 256]
+        self.inner_k_tiles = inner_k_tiles
+        self.groupsize = groupsize
+        self.precision = precision
+        self.scales_precision = scales_precision
+
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        *args: Any,
+        **kwargs: Any
+    ) -> torch.nn.Module:
+        _replace_linear_int4(
+            model,
+            self.groupsize,
+            self.inner_k_tiles,
+            padding_allowed=True,
+            precision=self.precision,
+            scales_precision=self.scales_precision,
+            linear_class=Int4WeightOnlyQATLinear,
+            copy_weights=True,
+        )
+        return model
+
+    def convert(
+        self,
+        model: torch.nn.Module,
+        *args: Any,
+        **kwargs: Any
+    ) -> torch.nn.Module:
+        _convert_qat_linear_4w(model)
+        return model
+
+def _convert_qat_linear_4w(module: torch.nn.Module):
+    """
+    Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`.
+ """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + groupsize = child.groupsize + inner_k_tiles = child.inner_k_tiles + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, child.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + _convert_qat_linear_4w(child) + +class Int4WeightOnlyQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int4 fake quantized grouped + per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, + which uses the efficient int4 tinygemm kernel. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert not bias, "require bias=False" + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_bit = 4 + qmin = 0 + qmax = 2 ** n_bit - 1 + scales, zero_points = get_groupwise_affine_qparams( + self.weight, n_bit, self.groupsize, self.scales_precision, + ) + w_fq = _Int4WeightOnlyFakeQuantize.apply( + self.weight, scales, zero_points, qmin, qmax, self.groupsize, + ) + return F.linear(x, w_fq) + +def enable_4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.enable_fake_quant() + +def disable_4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.disable_fake_quant() + + # ======================== # | QUANT PRIMITIVES | # ======================== +class _Int4WeightOnlyFakeQuantize(torch.autograd.Function): + """ + Implementation of int4 grouped per channel weight-only fake quantize + intended to match the numerics of the efficient int4 tinygemm kernel. 
+ """ + + @staticmethod + def forward(ctx, input, scales, zero_points, quant_min, quant_max, groupsize): + n_bit = 4 + w_q = groupwise_affine_quantize_tensor_from_qparams( + input, scales, zero_points, n_bit, groupsize, cast_dtypes=False, + ) + w_dq = groupwise_affine_dequantize_tensor_from_qparams( + w_q, scales, zero_points, n_bit, groupsize, cast_dtypes=False, + ) + mask = torch.logical_and((w_q >= quant_min), (w_q <= quant_max)) + ctx.save_for_backward(mask) + return w_dq + + @staticmethod + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None + class _GenericFakeQuantize(torch.autograd.Function): """ Implementation of generic fake quantize with backward STE. diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index a78c42605..876be13e6 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -174,6 +174,29 @@ def quantize_affine( Output: quantized tensor with requested dtype """ + return _do_quantize_affine( + input, + block_size, + scale, + zero_point, + output_dtype, + quant_min, + quant_max, + zero_point_domain, + cast_dtypes=True, + ) + +def _do_quantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + cast_dtypes: bool = True, +): # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}" @@ -191,7 +214,9 @@ def quantize_affine( if zero_point_domain == ZeroPointDomain.INT: quant = torch.clamp( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max - ).to(output_dtype) + ) + if cast_dtypes: + quant = quant.to(output_dtype) else: assert zero_point_domain == ZeroPointDomain.FLOAT mid_point = (quant_max + quant_min + 1) / 2 @@ -200,7 +225,9 @@ def quantize_affine( torch.clamp( torch.round((input - min_val) / scale), quant_min, quant_max) - ).to(output_dtype) + ) + if cast_dtypes: + quant = quant.to(output_dtype) quant = quant.view(original_shape) return quant @@ -238,11 +265,37 @@ def dequantize_affine( Output: dequantized Tensor, with requested dtype or fp32 """ - + return _do_dequantize_affine( + input, + block_size, + scale, + zero_point, + input_dtype, + quant_min, + quant_max, + zero_point_domain, + output_dtype=output_dtype, + cast_dtypes=True, + ) + +def _do_dequantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + input_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + *, + output_dtype: torch.dtype = torch.float32, + cast_dtypes: bool = True, +): # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size - assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}" - assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" + if cast_dtypes: + assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}" + assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" quant_min, 
quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) @@ -256,24 +309,34 @@ def dequantize_affine( zero_point = zero_point.view(shape_after_reduction) if zero_point_domain == ZeroPointDomain.INT: - # Force a copy to avoid input modification due - # to upcoming in-place operations. - dequant = input.to(torch.int32, copy=True) - if zero_point is not None: - dequant -= zero_point.to(torch.int32) - dequant = dequant.to(output_dtype) - dequant *= scale + if cast_dtypes: + # Force a copy to avoid input modification due + # to upcoming in-place operations. + dequant = input.to(torch.int32, copy=True) + if zero_point is not None: + dequant -= zero_point.to(torch.int32) + dequant = dequant.to(output_dtype) + dequant *= scale + else: + dequant = input.clone() + if zero_point is not None: + dequant -= zero_point + dequant *= scale else: assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}" mid_point = (quant_max + quant_min + 1) / 2 # This should allocate new memory and avoid input modification dequant = input - mid_point - dequant = dequant.to(output_dtype) + if cast_dtypes: + dequant = dequant.to(output_dtype) dequant *= scale if zero_point is not None: dequant += zero_point - return dequant.view(original_shape).to(output_dtype) + dequant = dequant.view(original_shape) + if cast_dtypes: + dequant = dequant.to(output_dtype) + return dequant def choose_qparams_affine( input: torch.Tensor, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3e3943c93..a3205b8fb 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -10,6 +10,8 @@ import torch.nn.utils.parametrize as parametrize from torchao.utils import find_multiple from .quant_primitives import ( + _do_quantize_affine, + _do_dequantize_affine, MappingType, ZeroPointDomain, choose_qparams_affine, @@ -333,6 +335,7 @@ def groupwise_affine_quantize_tensor_from_qparams( zeros, n_bit=4, groupsize=128, + cast_dtypes=True, ): assert groupsize > 1 # needed for GPTQ single column quantize @@ -347,7 +350,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_min = 0 quant_max = 2 ** n_bit - 1 - return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) + return _do_quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, cast_dtypes=cast_dtypes) def groupwise_affine_dequantize_tensor_from_qparams( w_int4x8, @@ -355,6 +358,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( zeros, n_bit=4, groupsize=128, + cast_dtypes=True, ): assert groupsize > 1 # needed for GPTQ single column dequantize @@ -367,7 +371,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( input_dtype = torch.int32 quant_min = 0 quant_max = 2**n_bit - 1 - return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype) + return _do_dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype, cast_dtypes=cast_dtypes) def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
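Example usage (not part of the patch): the snippet below is a minimal sketch of the int4 weight-only QAT flow added in this change, mirroring the prepare/convert sequence exercised in test_qat_4w_quantizer. The toy model, optimizer, and training loop are illustrative placeholders; it assumes a CUDA device and bfloat16 weights, which Int4WeightOnlyQATLinear and the tinygemm conversion require, and linear in_features divisible by both groupsize and 16 * inner_k_tiles.

import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer

# Placeholder model (hypothetical shapes): 512 and 256 are divisible by
# groupsize=32 and by 16 * inner_k_tiles=128, so no padding is needed
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256, bias=False),
    torch.nn.Linear(256, 512, bias=False),
).to("cuda").to(torch.bfloat16)

# prepare: swap nn.Linear -> Int4WeightOnlyQATLinear (int4 fake quantized weights)
qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
model = qat_quantizer.prepare(model)

# Fine-tune as usual; the forward pass sees int4-rounded weights and
# gradients flow back through the straight-through estimator
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
    loss = model(x).sum()  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# convert: swap Int4WeightOnlyQATLinear -> WeightOnlyInt4Linear, which packs the
# trained weights for the int4 tinygemm kernel (torch.ops.aten._weight_int4pack_mm)
model = qat_quantizer.convert(model)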