Add int4 weight-only QAT flow targeting tinygemm kernel #1570

Merged: 1 commit, Sep 26, 2024

Changes from all commits
63 changes: 58 additions & 5 deletions torchtune/training/quantization.py
@@ -6,23 +6,38 @@

from typing import Callable, Optional

from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization import (
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int4WeightOnlyQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.prototype.qat._module_swap_api import (
    disable_4w_fake_quant_module_swap,
    disable_8da4w_fake_quant_module_swap,
    enable_4w_fake_quant_module_swap,
    enable_8da4w_fake_quant_module_swap,
    Int4WeightOnlyQATQuantizerModuleSwap,
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)


__all__ = [
    "get_quantizer_mode",
    "Int4WeightOnlyQuantizer",
    "Int4WeightOnlyQATQuantizer",
    "Int4WeightOnlyQATQuantizerModuleSwap",
    "Int8DynActInt4WeightQuantizer",
    "Int8DynActInt4WeightQATQuantizer",
    "Int8DynActInt4WeightQATQuantizerModuleSwap",
]


@@ -57,14 +72,52 @@ def quantize(self, model):
_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant


# ==================
# int4 weight only |
# ==================


class Int4WeightOnlyQuantizer:
    """
    Quantizer for applying int4 per-group weight-only quantization
    to linear layers in the model using the efficient tinygemm kernel.
    """

    def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

    def quantize(self, model):
        layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles)
        quantize_fn = int4_weight_only(self.groupsize, layout_type)
        quantize_(model, quantize_fn)
        return model
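
For context, a minimal usage sketch of the new quantizer (the toy model is illustrative; tinygemm's int4 kernel generally expects bfloat16 weights on a CUDA device):

import torch

# Stand-in for a real model; tinygemm targets CUDA + bfloat16 linear layers.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)
quantizer = Int4WeightOnlyQuantizer(groupsize=128, inner_k_tiles=8)
model = quantizer.quantize(model)  # weights repacked into the tensor-core tiled int4 layout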


_quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
_quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
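
These registry entries are how recipes map a quantizer instance back to its mode string; a small sketch of that lookup, assuming get_quantizer_mode (defined earlier in this module) resolves by quantizer type:

quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)
mode = get_quantizer_mode(quantizer)  # expected: "4w-qat"
if mode is not None and mode.endswith("-qat"):
    enable_fn = _quantizer_mode_to_enable_fake_quant[mode]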


# =============
# module swap |
# =============

# Note: the QAT tensor subclass implementation in torchao only works
# with FSDP2 today. For other distribution strategies like DDP and
# FSDP1, users will need to fall back to the old module swap flow.
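
Given that constraint, a hedged sketch of how a recipe might choose between the two flows (sharding_strategy is a hypothetical config value, not an existing torchtune flag):

sharding_strategy = "fsdp2"  # hypothetical config value
if sharding_strategy == "fsdp2":
    quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)  # tensor subclass flow
else:  # DDP or FSDP1: fall back to module swap
    quantizer = Int4WeightOnlyQATQuantizerModuleSwap(groupsize=128)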

# int4 weight-only
_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
    "4w-qat-module-swap"
] = disable_4w_fake_quant_module_swap
_quantizer_mode_to_enable_fake_quant[
    "4w-qat-module-swap"
] = enable_4w_fake_quant_module_swap

# int8 dynamic activations + int4 weight
_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
    "8da4w-qat-module-swap"
] = disable_8da4w_fake_quant_module_swap
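
For completeness, a sketch of how these toggles are typically consumed in a QAT recipe. The warm-up pattern and toy model are illustrative; it assumes torchao's enable/disable helpers are applied via Module.apply after the quantizer's prepare() has swapped in fake-quant modules, and the enable-side entry for this mode is assumed by symmetry with the 4w block above:

import torch

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
model = Int8DynActInt4WeightQATQuantizerModuleSwap().prepare(model)  # swap in fake-quant linears

# Delay fake quant during early warm-up steps, then enable it for the rest of training.
model.apply(_quantizer_mode_to_disable_fake_quant["8da4w-qat-module-swap"])
# ... run N warm-up steps ...
model.apply(_quantizer_mode_to_enable_fake_quant["8da4w-qat-module-swap"])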