diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 7046f90b6..debe49ab1 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -6,23 +6,38 @@
 
 from typing import Callable, Optional
 
-from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.dtypes import TensorCoreTiledLayoutType
+from torchao.quantization import (
+    int4_weight_only,
+    int8_dynamic_activation_int4_weight,
+    quantize_,
+)
 from torchao.quantization.prototype.qat import (
+    disable_4w_fake_quant,
     disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
     enable_8da4w_fake_quant,
+    Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
 from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_4w_fake_quant_module_swap,
     disable_8da4w_fake_quant_module_swap,
+    enable_4w_fake_quant_module_swap,
     enable_8da4w_fake_quant_module_swap,
+    Int4WeightOnlyQATQuantizerModuleSwap,
     Int8DynActInt4WeightQATQuantizerModuleSwap,
 )
 
 
 __all__ = [
     "get_quantizer_mode",
+    "Int4WeightOnlyQuantizer",
+    "Int4WeightOnlyQATQuantizer",
+    "Int4WeightOnlyQATQuantizerModuleSwap",
     "Int8DynActInt4WeightQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "Int8DynActInt4WeightQATQuantizerModuleSwap",
 ]
 
 
@@ -57,14 +72,52 @@ def quantize(self, model):
 _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
 
 
-# ====================================================
-# int8 dynamic activations + int4 weight module swap |
-# ====================================================
+# ==================
+# int4 weight only |
+# ==================
+
+
+class Int4WeightOnlyQuantizer:
+    """
+    Quantizer for applying int4 per group weight only quantization
+    to linear layers in the model using the efficient tinygemm kernel.
+    """
+
+    def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+
+    def quantize(self, model):
+        layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles)
+        quantize_fn = int4_weight_only(self.groupsize, layout_type)
+        quantize_(model, quantize_fn)
+        return model
+
+
+_quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
+_quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
+_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
+_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
+
+
+# =============
+# module swap |
+# =============
 
 # Note: QAT tensor subclass implementation in torchao only works
 # with FSDP2 today. For other distribution strategies like DDP and
 # FSDP1, users will need to fall back to the old module swap flow.
-__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+
+# int4 weight-only
+_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "4w-qat-module-swap"
+] = disable_4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "4w-qat-module-swap"
+] = enable_4w_fake_quant_module_swap
+
+# int8 dynamic activations + int4 weight
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"
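
Usage sketch (not part of the patch): with the registrations above in place, a recipe could apply post-training int4 weight-only quantization roughly as follows. This is a minimal sketch, assuming a bfloat16 model already loaded on CUDA (the tinygemm kernel targets GPU) and the default groupsize/inner_k_tiles values; `model` is a placeholder for any module containing linear layers.

    from torchtune.training.quantization import get_quantizer_mode, Int4WeightOnlyQuantizer

    # Pack linear weights into the int4 tensor-core-tiled layout used by tinygemm.
    quantizer = Int4WeightOnlyQuantizer(groupsize=128, inner_k_tiles=8)
    model = quantizer.quantize(model)

    # Mode string registered by this patch; recipes use it to tag quantized checkpoints.
    assert get_quantizer_mode(quantizer) == "4w"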