Add int4 weight-only QAT flow targeting tinygemm kernel #1570

Merged · 1 commit · Sep 26, 2024

Commits on Sep 13, 2024

  1. Add int4 weight-only QAT flow targeting tinygemm kernel

    Summary: This commit adds an int4 weight-only QAT flow targeting
    the efficient tinygemm kernel. During fine-tuning we only simulate
    the numerics of the kernel in bf16; the actual kernel is invoked
    only after the model is quantized. For more detail, see
    pytorch/ao#383.
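
    Under the hood this follows torchao's prepare/convert QAT pattern.
    A minimal sketch is below (the quantizer class matches the component
    used in the commands in this test plan; the exact torchao import
    path may differ across versions):
    ```
    from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer

    qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)

    # prepare: swap linear weights for fake-quantized versions that
    # simulate int4 tinygemm numerics in bf16 during fine-tuning
    model = qat_quantizer.prepare(model)

    # ... fine-tune as usual ...

    # convert: replace fake-quantized weights with real int4 weights
    # packed in the layout the tinygemm kernel expects
    model = qat_quantizer.convert(model)
    ```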
    
    Test Plan:
    
    Fine-tune QAT command:
    ```
    tune run --nnodes 1 --nproc_per_node 6 --rdzv_endpoint="localhost:8900" qat_distributed --config llama3/8B_qat_full \
        batch_size=8 \
        fake_quant_after_n_steps=1000 \
        checkpointer.output_dir="/tmp/qat_results" \
        quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQATQuantizer \
        quantizer.groupsize=128
    ```
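
    The `fake_quant_after_n_steps=1000` flag delays fake quantization so
    the first steps of fine-tuning run in plain bf16. A rough sketch of
    the gating logic in the training loop (the `enable_fake_quant` /
    `disable_fake_quant` helpers here are illustrative names, not the
    exact recipe API):
    ```
    # hypothetical sketch; helper names are illustrative
    model.apply(disable_fake_quant)  # start in plain bf16

    for step, batch in enumerate(dataloader):
        if step == fake_quant_after_n_steps:
            # from here on, simulate int4 tinygemm numerics in bf16
            model.apply(enable_fake_quant)
        loss = model(batch).mean()   # stand-in for the real loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    ```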
    
    Quantize command:
    ```
    tune run quantize --config recipes/configs/quantization.yaml \
        model._component_=torchtune.models.llama3.llama3_8b \
        quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
        quantizer.groupsize=128 \
        checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
        checkpointer.checkpoint_dir="/tmp/qat_results" \
        checkpointer.output_dir="/tmp/qat_results" \
        checkpointer.checkpoint_files=[meta_model_2.pt] \
        checkpointer.model_type=LLAMA3
    ```
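
    This step is what actually packs the fine-tuned weights into the
    int4 layout tinygemm expects and writes the quantized checkpoint
    (the `-4w` suffix loaded by the eval command below). Conceptually,
    assuming torchao's post-training quantizer API (import path may
    differ by version):
    ```
    from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

    # use the same groupsize as during QAT fine-tuning so the
    # quantization numerics match what was simulated
    ptq_quantizer = Int4WeightOnlyQuantizer(groupsize=128)
    model = ptq_quantizer.quantize(model)
    ```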
    
    Eval command:
    ```
    tune run eleuther_eval --config eleuther_evaluation \
        tasks="[hellaswag, wikitext]" \
        model._component_=torchtune.models.llama3.llama3_8b \
        quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
        quantizer.groupsize=128 \
        checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
        checkpointer.checkpoint_dir="/tmp/qat_results" \
        checkpointer.output_dir="/tmp/qat_results" \
        checkpointer.checkpoint_files=[meta_model_2-4w.pt] \
        checkpointer.model_type=LLAMA3 \
        tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
        tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
    ```
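
    At eval time the quantizer is passed again so the recipe can rebuild
    the int4-packed model before loading the quantized state dict.
    Simplified sketch (the real recipe wires this through the configured
    checkpointer):
    ```
    import torch

    quantizer = Int4WeightOnlyQuantizer(groupsize=128)
    model = quantizer.quantize(model)  # recreate tinygemm weight layout
    state_dict = torch.load("/tmp/qat_results/meta_model_2-4w.pt")
    model.load_state_dict(state_dict, assign=True)
    ```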
    
    Evaluation results:
    ```
    |    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
    |--------------|------:|------|-----:|------|-----:|---|-----:|
    |truthfulqa_mc2|      2|none  |     0|acc   |0.4806|±  |0.0167|
    
    |    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
    |--------------|------:|------|-----:|------|-----:|---|-----:|
    |truthfulqa_mc2|      2|none  |     0|acc   |0.4914|±  |0.0164|
    
    |    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
    |--------------|------:|------|-----:|------|-----:|---|-----:|
    |truthfulqa_mc2|      2|none  |     0|acc   |0.4872|±  |0.0167|
    ```
    andrewor14 committed Sep 13, 2024
    Commit c0c4252