diff --git a/recipes/quantize.py b/recipes/quantize.py
index 3b90b62fe..1b0f65d0c 100644
--- a/recipes/quantize.py
+++ b/recipes/quantize.py
@@ -47,7 +47,7 @@ class QuantizationRecipe:
             multiple of groupsize.
             `percdamp`: GPTQ stablization hyperparameter, recommended to be .01

-    8da4w:
+    8da4w (PyTorch 2.3+):
         torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
         int8 per token dynamic activation with int4 weight only per axis group quantization
         Args:
diff --git a/torchtune/utils/quantization.py b/torchtune/utils/quantization.py
index 5f97dea0f..5486a84b6 100644
--- a/torchtune/utils/quantization.py
+++ b/torchtune/utils/quantization.py
@@ -11,15 +11,14 @@
     apply_weight_only_int8_quant,
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightQuantizer,
     Quantizer,
 )
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

 __all__ = [
     "Int4WeightOnlyQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
     "Int8WeightOnlyQuantizer",
-    "Int8DynActInt4WeightQuantizer",
     "get_quantizer_mode",
 ]

@@ -36,10 +35,16 @@ def quantize(
     Int4WeightOnlyQuantizer: "4w",
     Int8WeightOnlyQuantizer: "8w",
     Int4WeightOnlyGPTQQuantizer: "4w-gptq",
-    Int8DynActInt4WeightQuantizer: "8da4w",
 }

+if TORCH_VERSION_AFTER_2_3:
+    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+    __all__.append("Int8DynActInt4WeightQuantizer")
+    _quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
+
+
 def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
     """Given a quantizer object, returns a string that specifies the type of
     quantization e.g. 4w, which means int4 weight only quantization.
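
A minimal usage sketch (not part of the patch) of how the gated export behaves downstream. It assumes torchtune is importable, uses only names touched in this diff, and relies on the quantizer constructors' default arguments:

# Sketch only: exercises the conditional registration added in this diff.
from torchtune.utils import quantization

if "Int8DynActInt4WeightQuantizer" in quantization.__all__:
    # PyTorch 2.3+: the 8da4w quantizer was imported and registered above.
    quantizer = quantization.Int8DynActInt4WeightQuantizer()
    assert quantization.get_quantizer_mode(quantizer) == "8da4w"
else:
    # Older PyTorch: the name is never exported, so fall back to an
    # unconditionally available mode such as int4 weight only.
    quantizer = quantization.Int4WeightOnlyQuantizer()
    assert quantization.get_quantizer_mode(quantizer) == "4w"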