Add int4 weight-only QAT flow targeting tinygemm kernel
Summary: This commit adds an int4 weight-only QAT flow targeting
the efficient tinygemm kernel. During fine-tuning we only simulate
the numerics of the kernel in bf16; the kernel itself is only
called after the model is quantized. For more detail, see
pytorch/ao#383.
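
At a high level the flow has three steps: prepare the model with fake quantization, fine-tune in bf16, then quantize for real. A minimal Python sketch, assuming torchao's two-step QAT API (`prepare`) and the `Int4WeightOnlyQuantizer` wrapper added in this commit; the toy model and dimensions are placeholders:
```
import torch
from torch import nn

from torchtune.training.quantization import (
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyQuantizer,
)

# Toy stand-in for the real model (tinygemm expects bf16 on CUDA)
model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to(torch.bfloat16).cuda()

# 1. Prepare: insert fake quantization that simulates int4 tinygemm
#    numerics while weights and activations stay in bf16
qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)
model = qat_quantizer.prepare(model)

# 2. Fine-tune as usual (omitted)

# 3. Quantize: only now are the weights actually packed for tinygemm
quantizer = Int4WeightOnlyQuantizer(groupsize=128)
model = quantizer.quantize(model)
```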

Test Plan:

Fine-tune QAT command:
```
tune run --nnodes 1 --nproc_per_node 6 --rdzv_endpoint="localhost:8900" qat_distributed --config llama3/8B_qat_full \
    batch_size=8 \
    fake_quant_after_n_steps=1000 \
    checkpointer.output_dir="/tmp/qat_results" \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQATQuantizer \
    quantizer.groupsize=128
```
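
The `fake_quant_after_n_steps=1000` flag delays fake quantization so the first steps run with plain bf16 numerics. A rough sketch of what the recipe presumably does with the torchao toggles wired up in this commit (`model`, `dataloader`, and the loss computation are placeholders):
```
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    enable_4w_fake_quant,
)

fake_quant_after_n_steps = 1000

model.apply(disable_4w_fake_quant)  # start with fake quant disabled
for step, batch in enumerate(dataloader):
    if step == fake_quant_after_n_steps:
        model.apply(enable_4w_fake_quant)  # simulate int4 numerics from here on
    loss = model(batch).mean()  # placeholder loss
    loss.backward()
```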

Quantize command:
```
tune run quantize --config recipes/configs/quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
    quantizer.groupsize=128 \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="/tmp/qat_results" \
    checkpointer.output_dir="/tmp/qat_results" \
    checkpointer.checkpoint_files=[meta_model_2.pt] \
    checkpointer.model_type=LLAMA3
```
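
Under the hood, the new `Int4WeightOnlyQuantizer` (see the diff below) is a thin wrapper around torchao's `quantize_`; the following is equivalent to `Int4WeightOnlyQuantizer(groupsize=128, inner_k_tiles=8).quantize(model)`:
```
from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization import int4_weight_only, quantize_

# Pack weights into the tensor-core-tiled layout the tinygemm kernel expects
quantize_(model, int4_weight_only(128, TensorCoreTiledLayoutType(8)))
```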

Eval command:
```
tune run eleuther_eval --config eleuther_evaluation \
    tasks="[hellaswag, wikitext]" \
    model._component_=torchtune.models.llama3.llama3_8b \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
    quantizer.groupsize=128 \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="/tmp/qat_results" \
    checkpointer.output_dir="/tmp/qat_results" \
    checkpointer.checkpoint_files=[meta_model_2-4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
```

Evaluation results:
```
|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4806|±  |0.0167|

|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4914|±  |0.0164|

|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4872|±  |0.0167|
```
andrewor14 committed Sep 13, 2024
1 parent 6b43a1c commit c0c4252
63 changes: 58 additions & 5 deletions torchtune/training/quantization.py
```
@@ -6,23 +6,38 @@

 from typing import Callable, Optional

-from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.dtypes import TensorCoreTiledLayoutType
+from torchao.quantization import (
+    int4_weight_only,
+    int8_dynamic_activation_int4_weight,
+    quantize_,
+)
 from torchao.quantization.prototype.qat import (
+    disable_4w_fake_quant,
     disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
     enable_8da4w_fake_quant,
+    Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
 from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_4w_fake_quant_module_swap,
     disable_8da4w_fake_quant_module_swap,
+    enable_4w_fake_quant_module_swap,
     enable_8da4w_fake_quant_module_swap,
+    Int4WeightOnlyQATQuantizerModuleSwap,
     Int8DynActInt4WeightQATQuantizerModuleSwap,
 )


 __all__ = [
     "get_quantizer_mode",
+    "Int4WeightOnlyQuantizer",
+    "Int4WeightOnlyQATQuantizer",
+    "Int4WeightOnlyQATQuantizerModuleSwap",
     "Int8DynActInt4WeightQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "Int8DynActInt4WeightQATQuantizerModuleSwap",
 ]


@@ -57,14 +72,52 @@ def quantize(self, model):
 _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant


-# ====================================================
-# int8 dynamic activations + int4 weight module swap |
-# ====================================================
+# ==================
+# int4 weight only |
+# ==================


+class Int4WeightOnlyQuantizer:
+    """
+    Quantizer for applying int4 per group weight only quantization
+    to linear layers in the model using the efficient tinygemm kernel.
+    """
+
+    def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+
+    def quantize(self, model):
+        layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles)
+        quantize_fn = int4_weight_only(self.groupsize, layout_type)
+        quantize_(model, quantize_fn)
+        return model
+
+
+_quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
+_quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
+_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
+_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
+
+
+# =============
+# module swap |
+# =============
+
+# Note: QAT tensor subclass implementation in torchao only works
+# with FSDP2 today. For other distribution strategies like DDP and
+# FSDP1, users will need to fall back to the old module swap flow.
-__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+
+# int4 weight-only
+_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "4w-qat-module-swap"
+] = disable_4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "4w-qat-module-swap"
+] = enable_4w_fake_quant_module_swap

 # int8 dynamic activations + int4 weight
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"
```
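
Per the note in the diff, the tensor-subclass QAT flow only works with FSDP2 today; with DDP or FSDP1 the module-swap variants are the fallback. A hedged sketch, assuming the module-swap quantizer exposes the same two-step API as its tensor-subclass counterpart:
```
from torchtune.training.quantization import Int4WeightOnlyQATQuantizerModuleSwap

# Same prepare -> fine-tune -> quantize flow, but fake quantization is
# implemented by swapping in QAT linear modules instead of tensor subclasses
qat_quantizer = Int4WeightOnlyQATQuantizerModuleSwap(groupsize=128)
model = qat_quantizer.prepare(model)
```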
