Summary: This commit adds an int4 weight-only QAT flow targeting
the efficient tinygemm kernel. During fine-tuning we only simulate
the kernel's numerics in bf16; the kernel itself is called only
after the model has been quantized. For more detail, see
pytorch/ao#383.
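For illustration, a minimal sketch of what "simulating the numerics" means: each group of `groupsize` weights is rounded to one of 16 levels and immediately mapped back to floating point, so the forward pass sees quantization error while the weights stay in high precision. This is not the torchao implementation. The helper name `fake_quantize_int4`, the asymmetric min/max scheme, and the use of plain Python floats instead of bf16 tensors are assumptions made for the example, and the straight-through gradient handling that training needs is omitted.

```python
from typing import List


def fake_quantize_int4(weights: List[float], groupsize: int = 128) -> List[float]:
    """Quantize then dequantize each group to 16 levels (hypothetical helper)."""
    if len(weights) % groupsize != 0:
        raise ValueError("len(weights) must be divisible by groupsize")

    qmin, qmax = 0, 15  # unsigned 4-bit integer range
    out: List[float] = []
    for start in range(0, len(weights), groupsize):
        group = weights[start:start + groupsize]
        lo, hi = min(group), max(group)
        # One scale per group; fall back to 1.0 when the group is constant.
        scale = (hi - lo) / (qmax - qmin) or 1.0
        for w in group:
            # Round to the nearest level and clamp to the int4 range.
            q = max(qmin, min(qmax, round((w - lo) / scale)))
            # Map back to floating point: this is the "fake" part.
            out.append(q * scale + lo)
    return out
```

Calling `fake_quantize_int4(w, groupsize=128)` on a flat list of weights returns a list of the same length in which every group contains at most 16 distinct values.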
Test Plan:
Fine-tune QAT command:
```
tune run --nnodes 1 --nproc_per_node 6 --rdzv_endpoint="localhost:8900" qat_distributed --config llama3/8B_qat_full \
batch_size=8 \
fake_quant_after_n_steps=1000 \
checkpointer.output_dir="/tmp/qat_results" \
quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQATQuantizer \
quantizer.groupsize=128
```
Quantize command:
```
tune run quantize --config recipes/configs/quantization.yaml \
model._component_=torchtune.models.llama3.llama3_8b \
quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
quantizer.groupsize=128 \
checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
checkpointer.checkpoint_dir="/tmp/qat_results" \
checkpointer.output_dir="/tmp/qat_results" \
checkpointer.checkpoint_files=[meta_model_2.pt] \
checkpointer.model_type=LLAMA3
```
Eval command:
```
tune run eleuther_eval --config eleuther_evaluation \
tasks="[hellaswag, wikitext]" \
model._component_=torchtune.models.llama3.llama3_8b \
quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
quantizer.groupsize=128 \
checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir="/tmp/qat_results" \
checkpointer.output_dir="/tmp/qat_results" \
checkpointer.checkpoint_files=[meta_model_2-4w.pt] \
checkpointer.model_type=LLAMA3 \
tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
```
Evaluation results:
```
|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4806|±  |0.0167|

|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4914|±  |0.0164|

|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4872|±  |0.0167|
```