Update quantization to use tensor subclasses (pytorch#1403)
andrewor14 authored Sep 12, 2024
1 parent b0895a7 commit 6b43a1c
Showing 2 changed files with 54 additions and 35 deletions.
recipes/eleuther_eval.py (6 additions, 2 deletions)
```diff
@@ -222,11 +222,15 @@ def _setup_model(
     ) -> nn.Module:
         with training.set_default_dtype(self._dtype), self._device:
             model = config.instantiate(model_cfg)
 
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-            model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)
 
         # Put model in eval mode.
         # Note: This will not disable the dropout applied in SDPA,
```
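The substantive change here: when a quantization mode is active, the checkpoint tensors are moved to the eval device up front and then loaded with `assign=True`. After `quantize()` the model's linear weights are torchao tensor subclasses, so the in-place `copy_` that `load_state_dict` performs by default is not a reliable way to populate them; `assign=True` rebinds each parameter to the checkpoint tensor instead, which is why those tensors must already live on the right device. A minimal sketch of the pattern, using a toy model and a random state dict rather than the recipe's real checkpoint:

```python
# Hedged sketch of the load-with-assign pattern above. The model and state
# dict are stand-ins; only the device move + assign=True mirrors the recipe.
import torch
import torch.nn as nn

device = torch.device("cpu")  # stand-in for self._device
model = nn.Sequential(nn.Linear(16, 16))

# Pretend this came from torch.load() on a quantized checkpoint.
model_state_dict = {k: torch.randn_like(v) for k, v in model.state_dict().items()}

for k, v in model_state_dict.items():
    model_state_dict[k] = v.to(device)  # assign=True skips the copy, so move first
model.load_state_dict(model_state_dict, assign=True)  # rebind params to these tensors
```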
torchtune/training/quantization.py (48 additions, 33 deletions)
```diff
@@ -6,8 +6,23 @@
 
 from typing import Callable, Optional
 
+from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.quantization.prototype.qat import (
+    disable_8da4w_fake_quant,
+    enable_8da4w_fake_quant,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_8da4w_fake_quant_module_swap,
+    enable_8da4w_fake_quant_module_swap,
+    Int8DynActInt4WeightQATQuantizerModuleSwap,
+)
+
 
 __all__ = [
     "get_quantizer_mode",
+    "Int8DynActInt4WeightQuantizer",
+    "Int8DynActInt4WeightQATQuantizer",
 ]
 
@@ -16,47 +31,47 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
-from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+# ========================================================
+# int8 dynamic activations + int4 weight tensor subclass |
+# ========================================================
 
-__all__.append("Int8DynActInt4WeightQuantizer")
-_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 
+class Int8DynActInt4WeightQuantizer:
+    """
+    Quantizer for applying int8 per token dynamic activation + int4
+    per group weight quantization to linear layers in the model.
+    """
+
+    def __init__(self, groupsize: int = 256):
+        self.groupsize = groupsize
+
+    def quantize(self, model):
+        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_(model, quantize_fn)
+        return model
 
-from torchao.quantization.prototype.qat import (
-    disable_8da4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int8DynActInt4WeightQATQuantizer,
-)
 
-__all__.append("Int8DynActInt4WeightQATQuantizer")
+_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
 _quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
 _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
 
-try:
-    # Note: QAT tensor subclass implementation in torchao only works
-    # with FSDP2 today. For other distribution strategies like DDP and
-    # FSDP1, users will need to fall back to the old module swap flow.
-    # TODO: remove this try catch once we upgrade to torchao 0.5.0
-    from torchao.quantization.prototype.qat._module_swap_api import (
-        disable_8da4w_fake_quant_module_swap,
-        enable_8da4w_fake_quant_module_swap,
-        Int8DynActInt4WeightQATQuantizerModuleSwap,
-    )
-    __all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
-    _quantizer_to_mode[
-        Int8DynActInt4WeightQATQuantizerModuleSwap
-    ] = "8da4w-qat-module-swap"
-    _quantizer_mode_to_disable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = disable_8da4w_fake_quant_module_swap
-    _quantizer_mode_to_enable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = enable_8da4w_fake_quant_module_swap
-except ImportError:
-    pass
+
+# ====================================================
+# int8 dynamic activations + int4 weight module swap |
+# ====================================================
+
+# Note: QAT tensor subclass implementation in torchao only works
+# with FSDP2 today. For other distribution strategies like DDP and
+# FSDP1, users will need to fall back to the old module swap flow.
+__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "8da4w-qat-module-swap"
+] = disable_8da4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "8da4w-qat-module-swap"
+] = enable_8da4w_fake_quant_module_swap
 
 
 def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
```
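The commit replaces the re-exported `torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer` with a thin torchtune wrapper around `quantize_`, so the "8da4w" path now produces tensor-subclass weights instead of swapped modules. A hedged usage sketch (the two-layer model is a stand-in, not a torchtune model; `quantize_` mutates the model in place):

```python
# Illustrative only: how the new wrapper is meant to be called. groupsize
# must divide each linear layer's in_features (256 / 256 = 1 group here).
import torch.nn as nn
from torchtune.training.quantization import Int8DynActInt4WeightQuantizer

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
quantizer = Int8DynActInt4WeightQuantizer(groupsize=256)
model = quantizer.quantize(model)  # linear weights are now quantized tensor subclasses
```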

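The module-level dictionaries are how recipes discover the right hooks: `get_quantizer_mode` maps a quantizer instance to a mode string, and the `*-qat*` modes additionally resolve enable/disable fake-quant callbacks. A hedged sketch of the lookup, assuming `get_quantizer_mode` keys off the quantizer's type as the tables above suggest:

```python
# Illustrative lookup against the registries populated above. Per the note in
# the diff, the "-module-swap" variant is the fallback for DDP/FSDP1, since
# the tensor-subclass QAT flow currently only supports FSDP2.
from torchtune.training.quantization import (
    Int8DynActInt4WeightQATQuantizer,
    get_quantizer_mode,
)

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
assert get_quantizer_mode(quantizer) == "8da4w-qat"  # from _quantizer_to_mode
assert get_quantizer_mode(None) is None              # unregistered -> None
```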