Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add int4 weight-only QAT flow targeting tinygemm kernel
Summary: This commit adds an int4 weight-only QAT flow targeting the efficient tinygemm kernel. This means during fine-tuning we only simulate numerics of the kernel in bf16, but we only actually call the kernel after quantizing the model. For more detail, see pytorch/ao#383. Test Plan: Fine-tune QAT command: ``` tune run --nnodes 1 --nproc_per_node 6 --rdzv_endpoint="localhost:8900" qat_distributed --config llama3/8B_qat_full \ batch_size=8 \ fake_quant_after_n_steps=1000 \ checkpointer.output_dir="/tmp/qat_results" \ quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQATQuantizer \ quantizer.groupsize=128 ``` Quantize command: ``` tune run quantize --config recipes/configs/quantization.yaml \ model._component_=torchtune.models.llama3.llama3_8b \ quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \ quantizer.groupsize=128 \ checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \ checkpointer.checkpoint_dir="/tmp/qat_results" \ checkpointer.output_dir="/tmp/qat_results" \ checkpointer.checkpoint_files=[meta_model_2.pt] \ checkpointer.model_type=LLAMA3 ``` Eval command: ``` tune run eleuther_eval --config eleuther_evaluation \ tasks="[hellaswag, wikitext]" \ model._component_=torchtune.models.llama3.llama3_8b \ quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \ quantizer.groupsize=128 \ checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir="/tmp/qat_results" \ checkpointer.output_dir="/tmp/qat_results" \ checkpointer.checkpoint_files=[meta_model_2-4w.pt] \ checkpointer.model_type=LLAMA3 \ tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model ``` Evaluation results: ``` | Tasks |Version|Filter|n-shot|Metric|Value | |Stderr| |--------------|------:|------|-----:|------|-----:|---|-----:| |truthfulqa_mc2| 2|none | 0|acc |0.4806|± |0.0167| | Tasks |Version|Filter|n-shot|Metric|Value | |Stderr| |--------------|------:|------|-----:|------|-----:|---|-----:| |truthfulqa_mc2| 2|none | 0|acc |0.4914|± |0.0164| | Tasks |Version|Filter|n-shot|Metric|Value | |Stderr| |--------------|------:|------|-----:|------|-----:|---|-----:| |truthfulqa_mc2| 2|none | 0|acc |0.4872|± |0.0167| ```
- Loading branch information