diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py
index 4b80e132b..60cb83ff5 100644
--- a/recipes/eleuther_eval.py
+++ b/recipes/eleuther_eval.py
@@ -222,11 +222,15 @@ def _setup_model(
     ) -> nn.Module:
         with training.set_default_dtype(self._dtype), self._device:
             model = config.instantiate(model_cfg)
+
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-
-        model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)
 
         # Put model in eval mode.
         # Note: This will not disable the dropout applied in SDPA,
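Context for the `assign=True` change above: once the model has been quantized, its linear weights are torchao tensor subclasses, and copying checkpoint values into those parameters in place is not generally supported. Loading with `assign=True` instead adopts the checkpoint tensors as the module's parameters, which is why each tensor is moved to the target device first. A minimal sketch of the pattern, with a toy `nn.Linear` and a CPU device standing in for the recipe's model and `self._device`:

```python
import torch
from torch import nn

# Toy stand-in for the evaluated model (illustrative only).
model = nn.Linear(128, 128)
device = torch.device("cpu")  # plays the role of self._device

# Pretend this came from a checkpoint.
model_state_dict = {k: v.clone() for k, v in model.state_dict().items()}

# assign=True swaps the checkpoint tensors in as the parameters rather
# than copying values into the existing ones, so every tensor must
# already live on the target device.
for k, v in model_state_dict.items():
    model_state_dict[k] = v.to(device)
model.load_state_dict(model_state_dict, assign=True)
```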
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 8d734b010..7046f90b6 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -6,8 +6,23 @@
 
 from typing import Callable, Optional
 
+from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.quantization.prototype.qat import (
+    disable_8da4w_fake_quant,
+    enable_8da4w_fake_quant,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_8da4w_fake_quant_module_swap,
+    enable_8da4w_fake_quant_module_swap,
+    Int8DynActInt4WeightQATQuantizerModuleSwap,
+)
+
+
 __all__ = [
     "get_quantizer_mode",
+    "Int8DynActInt4WeightQuantizer",
+    "Int8DynActInt4WeightQATQuantizer",
 ]
 
 
@@ -16,47 +31,47 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
-from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+# ========================================================
+# int8 dynamic activations + int4 weight tensor subclass |
+# ========================================================
 
 
-__all__.append("Int8DynActInt4WeightQuantizer")
-_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
+class Int8DynActInt4WeightQuantizer:
+    """
+    Quantizer for applying int8 per token dynamic activation + int4
+    per group weight quantization to linear layers in the model.
+    """
+
+    def __init__(self, groupsize: int = 256):
+        self.groupsize = groupsize
+
+    def quantize(self, model):
+        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_(model, quantize_fn)
+        return model
 
 
-from torchao.quantization.prototype.qat import (
-    disable_8da4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int8DynActInt4WeightQATQuantizer,
-)
-__all__.append("Int8DynActInt4WeightQATQuantizer")
+_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
 _quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
 _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
 
-try:
-    # Note: QAT tensor subclass implementation in torchao only works
-    # with FSDP2 today. For other distribution strategies like DDP and
-    # FSDP1, users will need to fall back to the old module swap flow.
-    # TODO: remove this try catch once we upgrade to torchao 0.5.0
-
-    from torchao.quantization.prototype.qat._module_swap_api import (
-        disable_8da4w_fake_quant_module_swap,
-        enable_8da4w_fake_quant_module_swap,
-        Int8DynActInt4WeightQATQuantizerModuleSwap,
-    )
-
-    __all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
-    _quantizer_to_mode[
-        Int8DynActInt4WeightQATQuantizerModuleSwap
-    ] = "8da4w-qat-module-swap"
-    _quantizer_mode_to_disable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = disable_8da4w_fake_quant_module_swap
-    _quantizer_mode_to_enable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = enable_8da4w_fake_quant_module_swap
-except ImportError:
-    pass
+
+# ====================================================
+# int8 dynamic activations + int4 weight module swap |
+# ====================================================
+
+# Note: QAT tensor subclass implementation in torchao only works
+# with FSDP2 today. For other distribution strategies like DDP and
+# FSDP1, users will need to fall back to the old module swap flow.
+__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "8da4w-qat-module-swap"
+] = disable_8da4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "8da4w-qat-module-swap"
+] = enable_8da4w_fake_quant_module_swap
 
 
 def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
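The new wrapper keeps the old quantizer-style `.quantize(model)` interface while delegating to torchao's `quantize_` tensor-subclass API. A minimal usage sketch, assuming torchao is installed; the toy model is illustrative, and `groupsize` must divide each linear layer's `in_features`:

```python
from torch import nn

from torchtune.training.quantization import (
    get_quantizer_mode,
    Int8DynActInt4WeightQuantizer,
)

# Toy model; quantize_ rewrites the nn.Linear weights in place as
# int8-dynamic-activation / int4-weight tensor subclasses.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

quantizer = Int8DynActInt4WeightQuantizer(groupsize=256)
model = quantizer.quantize(model)

# The mode registry still resolves the wrapper to "8da4w", so recipe
# configs keyed on quantizer mode continue to work unchanged.
assert get_quantizer_mode(quantizer) == "8da4w"
```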