From 7baa4f7aab7aa3716e6e4cf0e797fcecf363f219 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Wed, 27 Sep 2023 23:40:27 +0900 Subject: [PATCH 01/16] [add] gpt-neox support --- fine-tune.py | 18 ++++- gptneox_attn_replace.py | 169 ++++++++++++++++++++++++++++++++++++++++ supervised-fine-tune.py | 16 +++- 3 files changed, 198 insertions(+), 5 deletions(-) create mode 100644 gptneox_attn_replace.py diff --git a/fine-tune.py b/fine-tune.py index 2bbc82ce..a9a3e72b 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -24,6 +24,7 @@ from torch.utils.data import Dataset from transformers import Trainer, DataCollatorForLanguageModeling from llama_attn_replace import replace_llama_attn +from gptneox_attn_replace import replace_gpt_neox_attn from peft import LoraConfig, get_peft_model from torch.distributed import barrier @@ -39,7 +40,8 @@ @dataclass class ModelArguments: - model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_name_or_path: Optional[str] = field(default="EleutherAI/gpt-neox-20b") + model_type: Optional[str] = field(default="gpt-neox") @dataclass class TrainingArguments(transformers.TrainingArguments): @@ -99,7 +101,11 @@ def train(): parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) model_args, training_args = parser.parse_args_into_dataclasses() - replace_llama_attn(training_args.use_flash_attn) + # NOTE: May expand supported model types in the future + if model_args.model_type == "gpt-neox": + replace_gpt_neox_attn(training_args.use_flash_attn) + else: + replace_llama_attn(training_args.use_flash_attn) # Set RoPE scaling factor config = transformers.AutoConfig.from_pretrained( @@ -157,10 +163,16 @@ def train(): data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) if training_args.low_rank_training: + if model_args.model_type == "gpt-neox": + # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' + targets = ["query_key_value", "dense"] + else: + targets=["q_proj", "k_proj", "v_proj", "o_proj"], + config = LoraConfig( r=8, lora_alpha=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + target_modules=targets, lora_dropout=0, bias="none", task_type="CAUSAL_LM", diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py new file mode 100644 index 00000000..df9b9a4d --- /dev/null +++ b/gptneox_attn_replace.py @@ -0,0 +1,169 @@ +# Modified based on https://github.com/dvlab-research/LongLoRA + +from typing import Optional, Tuple +import warnings +import torch +import transformers +from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb + +from flash_attn import flash_attn_varlen_func + + +group_size_ratio = 1/4 + + +def _flash_attn(query, key, value, attention_mask=None, head_mask=None): + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # q, k, v: [bs, nh, seq_len, hd] + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + value_length = value.size(-2) + + # q, k, v: [bs, nh, seq_len, hd] -> [bs, seq_len, nh, hd] -> [bs * seq_len, nh, hd] + query = query.transpose(1, 2).reshape(batch_size * query_length , num_attention_heads, attn_head_size) + key = key.transpose(1, 2).reshape(batch_size * key_length, num_attention_heads, attn_head_size) + value = value.transpose(1, 2).reshape(batch_size * value_length, num_attention_heads, attn_head_size) + + attn_dropout = 0.0 # TODO: attach to config + + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * query_length, + step=query_length, + dtype=torch.int32, + device=query.device, + ) + + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * key_length, + step=key_length, + dtype=torch.int32, + device=key.device, + ) + + attn_output, attn_weights, _ = flash_attn_varlen_func( + query, key, value, cu_seqlens_q, cu_seqlens_k, query_length, value_length, dropout_p=attn_dropout, + softmax_scale=None, causal=True, return_attn_probs=True + ) + + attn_output = attn_output.view(batch_size, query_length, num_attention_heads, attn_head_size).transpose(1, 2) + return attn_output, attn_weights + + +def get_forward_function(use_flash_attn=True, use_full=False): + + def forward_attention( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + # NOTE: compute SS group size + bsz, q_len, _ = hidden_states.size() + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] + # --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + # [bsz, nh, q_len, hd] + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + if has_layer_past: + seq_len += layer_past[0].shape[-2] + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # NOTE: apply shift + if self.training and not use_full: + def shift(qkv, num_heads, head_dim): + # qkv = [bsz, nh, q_len, d] + group_size = int(q_len * group_size_ratio) + if q_len % group_size > 0: + raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) + num_group = q_len // group_size + qkv = qkv.transpose(1, 2) + # qkv = [bsz, q_len, nh, d] + qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) + # -> [bsz * n_group, group_s, nh, d) + # -> [bsz * n_group, nh, group_s, d) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) + return qkv + + query = shift(query, self.num_attention_heads, self.head_size) + key = shift(key, self.num_attention_heads, self.head_size) + value = shift(value, self.num_attention_heads, self.head_size) + + # Compute attention + if use_flash_attn: + attn_output, attn_weights = _flash_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # NOTE: shift back + if self.training and not use_full: + attn_output = attn_output.transpose(1, 2) + # [bsz, q_len, nh, hd] + attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(group_size//2, dims=1) + attn_output = attn_output.transpose(1, 2) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + return forward_attention + + +def replace_gpt_neox_attn(use_flash_attn=True, use_full=False): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if use_flash_attn and cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." + "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + "Resorting to plain attention..." + ) + use_flash_attn = False + + forward_fn = get_forward_function(use_flash_attn, use_full) + transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = forward_fn diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index bf22029a..cecfb9b9 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -27,6 +27,7 @@ from torch.utils.data import Dataset from transformers import Trainer, DataCollatorForLanguageModeling from llama_attn_replace import replace_llama_attn +from gptneox_attn_replace import replace_gpt_neox_attn from peft import LoraConfig, get_peft_model from torch.distributed import barrier @@ -65,6 +66,7 @@ def jload(f, mode="r"): @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_type: Optional[str] = field(default="gpt-neox") @dataclass @@ -219,7 +221,11 @@ def train(): parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() - replace_llama_attn(training_args.use_flash_attn, True) + # NOTE: May expand supported model types in the future + if model_args.model_type == "gpt-neox": + replace_gpt_neox_attn(training_args.use_flash_attn) + else: + replace_llama_attn(training_args.use_flash_attn) # Set RoPE scaling factor config = transformers.AutoConfig.from_pretrained( @@ -266,10 +272,16 @@ def train(): data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) if training_args.low_rank_training: + if model_args.model_type == "gpt-neox": + # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' + targets = ["query_key_value", "dense"] + else: + targets=["q_proj", "k_proj", "v_proj", "o_proj"], + config = LoraConfig( r=8, lora_alpha=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + target_modules=targets, lora_dropout=0, bias="none", task_type="CAUSAL_LM", From 41977dfdce67fc73a29cacc5a1967520d5c07073 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Wed, 27 Sep 2023 23:40:42 +0900 Subject: [PATCH 02/16] [update] readme --- README.md | 81 ++++++++----------------------------------------------- 1 file changed, 11 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index ff2c503b..59512a95 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,11 @@ -[![Gradio](https://img.shields.io/badge/Gradio-Online%20Demo-blue)](https://2060079530708e861d.gradio.live) - -# LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models - -## News -- [x] [2023.9.22] We release our **13B and 70B 32k models with the supervised fine-tuning**, which is feasible for long context QA. Please check [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). To our best knowledge, **this is the first work that release 70B model with 32k context length**. -- [x] [2023.9.22] We release all our fine-tuned [models](https://huggingface.co/Yukang), including **70B-32k models**, [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k), [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft). Welcome to check them out! -- [x] [2023.9.22] We release [Paper](http://arxiv.org/abs/2309.12307) and this GitHub repo, including training and evaluation code. - -**LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models [[Paper](http://arxiv.org/abs/2309.12307)]**
-[Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl=en), -[Shengju Qian](https://scholar.google.com/citations?user=QNnWmasAAAAJ), -[Haotian Tang](https://scholar.google.com/citations?user=WxL13BAAAAAJ&hl), -[Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl=zh-CN), -[Zhijian Liu](https://scholar.google.com/citations?user=3coYSTUAAAAJ&hl=en), -[Song Han](https://scholar.google.com/citations?user=E0iCaa4AAAAJ&hl=zh-CN), -[Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en)
- -
**Paper** | **Models** | [**Training**](#training) | [**Inference**](#inference) | **Online Demo**
- -

-

-

-

-

-

- -## Abstract -We present LongLoRA, an efficient fine-tuning approach that extends the context sizes of pre-trained large language models (LLMs), with limited computation cost. -Typically, training LLMs with long context sizes is computationally expensive, requiring extensive training hours and GPU resources. -In this paper, we speed up the context extension of LLMs in two aspects. On the one hand, although dense global attention is needed during inference, fine-tuning the model can be effectively and efficiently done by sparse local attention. The proposed shift short attention effectively enables context extension, leading to non-trivial computation saving with similar performance to fine-tuning with vanilla attention. On the other hand, we find that LoRA for context extension works well under the premise of trainable embedding and normalization. LongLoRA demonstrates strong empirical results on various tasks on LLaMA2 models from 7B/13B to 70B. LongLoRA adopts LLaMA2 7B from 4k context to 100k, or LLaMA2 70B to 32k on a single 8x A100 machine. LongLoRA extends models' context while retaining their original architectures, and is compatible with most existing techniques, like FlashAttention-2. In addition, to make LongLoRA practical, we collect a dataset, LongQA, for supervised fine-tuning. It contains more than 3k long context question-answer pairs. For more details, please refer to the [paper](http://arxiv.org/abs/2309.12307). +# LongLoRA (with GPTNeoX support): Efficient Fine-tuning of Long-Context Large Language Models +This repo provides on top of the original implementation, support for GPTNeoX with Flash-Attention and the LongLoRA's shifted short attention as needed. ## Highlights **LongLoRA** speed up the context extension of pre-trained large language models in both attention-level and weight-level. 1. The proposed shifted short attention is easy to implement, compatible with Flash-Attention, and not required during inference. -2. We release all our models, including models from 7B to 70B, context length from 8k to 100k, including [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft), [LLaMA2-LongLoRA-13B-64k](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k), and [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k). -3. We build up a long-context QA dataset, LongQA, for supervised fine-tuning (SFT). We release 13B and 70B 32k models with SFT, [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). We will further release the dataset in the next month. + ## Installation ``` @@ -43,46 +13,16 @@ pip install -r requirements.txt pip install flash-attn --no-build-isolation ``` -## Released models - -### Models with supervised fine-tuning -| Model | Size | Context | Train | Link | -|:----------------------------------|------|---------|---------|-------------------------------------------------------------------------| -| Llama-2-13b-chat-longlora-32k-sft | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) | -| Llama-2-70b-chat-longlora-32k-sft | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft) | - -### Models with context extension via fully fine-tuning -| Model | Size | Context | Train | Link | -|:----------------------------|------|---------|-------|-------------------------------------------------------------------| -| Llama-2-7b-longlora-8k-ft | 7B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k-ft) | -| Llama-2-7b-longlora-16k-ft | 7B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k-ft) | -| Llama-2-7b-longlora-32k-ft | 7B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k-ft) | -| Llama-2-7b-longlora-100k-ft | 7B | 100000 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft) | -| Llama-2-13b-longlora-8k-ft | 13B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k-ft) | -| Llama-2-13b-longlora-16k-ft | 13B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k-ft) | -| Llama-2-13b-longlora-32k-ft | 13B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k-ft) | - -### Models with context extension via improved LoRA fine-tuning -| Model | Size | Context | Train | Link | -|:----------------------------|------|---------|-------|-------------------------------------------------------------------| -| Llama-2-7b-longlora-8k | 7B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k) | -| Llama-2-7b-longlora-16k | 7B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k) | -| Llama-2-7b-longlora-32k | 7B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k) | -| Llama-2-13b-longlora-8k | 13B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k) | -| Llama-2-13b-longlora-16k | 13B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k) | -| Llama-2-13b-longlora-32k | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k) | -| Llama-2-13b-longlora-64k | 13B | 65536 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k) | -| Llama-2-70b-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k) | -| Llama-2-70b-chat-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k) | - ## Training ### Pre-trained weights -We use LLaMA2 models as the pre-trained weights and fine-tune them to long context window sizes. Please download [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), [Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b-hf), and [Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf), based on your choices. +I used GPTNeoX model as the base model architecture, which was ported from the authors' original repo where Llama2 was used. +Some candidate pre-trained weights may include [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b), [Polyglot-ko-12.8B](https://huggingface.co/EleutherAI/polyglot-ko-12.8b) and other variants. + ### Fine-tuning ``` torchrun --nproc_per_node=8 fine-tune.py \ - --model_name_or_path path_to/Llama-2-7b-hf \ + --model_name_or_path path_to/gpt_neox_model_hf \ --bf16 True \ --output_dir path_to_saving_checkpoints \ --cache_dir path_to_cache \ @@ -107,7 +47,7 @@ torchrun --nproc_per_node=8 fine-tune.py \ --max_steps 1000 ``` -- Please remember to change `path_to/Llama-2-7b-hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. +- Please remember to change `path_to/gpt_neox_model_hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. - Note that you can change `model_max_length` to other values. - You could change `ds_configs/stage2.json` to `ds_configs/stage3.json` if you want. - Please set `use_flash_attn` as `False` if you use V100 machines or do not install flash attention. @@ -143,7 +83,7 @@ torchrun --nproc_per_node=8 supervised-fine-tune.py \ --deepspeed "ds_configs/stage2.json" \ --tf32 True ``` -- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models`, like `Llama-2-13b-longlora-32k` or `Llama-2-13b-longlora-32k-ft`. +- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models` - During our dataset collection, it is hard for us to collect many high-quality QA that are larger than 32768. Thus, if you use our `LongQA.json`, please also set `model_max_length` as 32768. @@ -282,7 +222,8 @@ If you find this project useful in your research, please consider citing: ``` ## Acknowledgement -- This work is built upon the [LLaMA2](https://ai.meta.com/llama) as the pre-trained models. +- This work is an GPTNeoX port of the work from the original authors' code. [LongLoRA](https://github.com/dvlab-research/LongLoRA) +- This work is built upon the [GPTNeoX-HF](https://huggingface.co/docs/transformers/model_doc/gpt_neox) which is based upon [EleutherAI/GPTNeoX](https://github.com/EleutherAI/gpt-neox) as the pre-trained model architecture. - This work is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed), [peft](https://github.com/huggingface/peft), and [Flash-Attention2](https://github.com/Dao-AILab/flash-attention) for acceleration. - Some evaluation code is modified upon [Landmark Attention](https://github.com/epfml/landmark-attention). - We use [LongChat](https://github.com/DachengLi1/LongChat) for the retrieval evaluation. From 9c9d0a2bdbb3c9430f0675d530f8ceb8f4049805 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Thu, 28 Sep 2023 23:17:40 +0900 Subject: [PATCH 03/16] [fix] some of the bugs preventing fine-tune run + There's still bugs in the attention dimensions mismatch --- fine-tune.py | 5 +++-- gptneox_attn_replace.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index a9a3e72b..185344fd 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -40,7 +40,7 @@ @dataclass class ModelArguments: - model_name_or_path: Optional[str] = field(default="EleutherAI/gpt-neox-20b") + model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped") model_type: Optional[str] = field(default="gpt-neox") @dataclass @@ -123,6 +123,7 @@ def train(): model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, + torch_dtype=torch.bfloat16, ) tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -130,7 +131,7 @@ def train(): cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", - use_fast=False, + use_fast=True, ) special_tokens_dict = dict() diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index df9b9a4d..4670e5dd 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -110,13 +110,13 @@ def forward_attention( present = (key, value) if use_cache else None # NOTE: apply shift + group_size = int(q_len * group_size_ratio) + if q_len % group_size > 0: + raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) + num_group = q_len // group_size if self.training and not use_full: def shift(qkv, num_heads, head_dim): # qkv = [bsz, nh, q_len, d] - group_size = int(q_len * group_size_ratio) - if q_len % group_size > 0: - raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) - num_group = q_len // group_size qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) @@ -139,7 +139,7 @@ def shift(qkv, num_heads, head_dim): if self.training and not use_full: attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] - attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(group_size//2, dims=1) + attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) attn_output = attn_output.transpose(1, 2) # Reshape outputs From a5111ef23487b1f93646360805e293fabe3c387d Mon Sep 17 00:00:00 2001 From: naubull2 Date: Fri, 29 Sep 2023 00:00:12 +0900 Subject: [PATCH 04/16] [fix] dimesion discrepancy between attention mask and the query length + group batch attention is skipped to avoid this problem for now --- gptneox_attn_replace.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 4670e5dd..e2a48184 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -120,9 +120,12 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) + qkv = qkv.transpose(1, 2) + + # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well # -> [bsz * n_group, group_s, nh, d) # -> [bsz * n_group, nh, group_s, d) - qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) + #qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) return qkv query = shift(query, self.num_attention_heads, self.head_size) From 5862050d5609590cac4134919b459e0157de52f5 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Fri, 29 Sep 2023 00:03:57 +0900 Subject: [PATCH 05/16] [fix] SFT to match the same mods in finetune.py --- supervised-fine-tune.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index cecfb9b9..fefbc2b2 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -65,7 +65,7 @@ def jload(f, mode="r"): @dataclass class ModelArguments: - model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped") model_type: Optional[str] = field(default="gpt-neox") @@ -243,6 +243,7 @@ def train(): model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, + torch_dtype=torch.bfloat16, ) tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -250,7 +251,7 @@ def train(): cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right", - use_fast=False, + use_fast=True, ) special_tokens_dict = dict() @@ -290,6 +291,7 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] + model.config.use_cache = False # required for gradient checkpointing model.enable_input_require_grads() # required for gradient checkpointing model.gradient_checkpointing_enable() # enable gradient checkpointing From 1532c4b3f665110dbecf9077825c6572c0baf69b Mon Sep 17 00:00:00 2001 From: naubull2 Date: Sat, 30 Sep 2023 22:05:28 +0900 Subject: [PATCH 06/16] [add] parallel group attention then reshape back to original form --- gptneox_attn_replace.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index e2a48184..3f398580 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -120,12 +120,12 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) - qkv = qkv.transpose(1, 2) + #qkv = qkv.transpose(1, 2) # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well # -> [bsz * n_group, group_s, nh, d) # -> [bsz * n_group, nh, group_s, d) - #qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) return qkv query = shift(query, self.num_attention_heads, self.head_size) @@ -140,7 +140,9 @@ def shift(qkv, num_heads, head_dim): # NOTE: shift back if self.training and not use_full: - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, num_heads, head_dim) + #attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) attn_output = attn_output.transpose(1, 2) From 6fdffbb4680b54fe7beafb10b8face00211217de Mon Sep 17 00:00:00 2001 From: naubull2 Date: Sat, 30 Sep 2023 22:39:39 +0900 Subject: [PATCH 07/16] [fix] non-contiguous dimensions changing view issue --- gptneox_attn_replace.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 3f398580..13531d13 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -128,9 +128,10 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) return qkv - query = shift(query, self.num_attention_heads, self.head_size) - key = shift(key, self.num_attention_heads, self.head_size) - value = shift(value, self.num_attention_heads, self.head_size) + # contiguous is required as self._attn() will attempt to apply .view() on them + query = shift(query, self.num_attention_heads, self.head_size).contiguous() + key = shift(key, self.num_attention_heads, self.head_size).contiguous() + value = shift(value, self.num_attention_heads, self.head_size).contiguous() # Compute attention if use_flash_attn: @@ -141,7 +142,7 @@ def shift(qkv, num_heads, head_dim): # NOTE: shift back if self.training and not use_full: attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, num_heads, head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size) #attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) From fe97f8697601d1f5aae584a179bfd5b70d8f6c64 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Sat, 30 Sep 2023 22:51:58 +0900 Subject: [PATCH 08/16] [add] attention mask to align with the grouped batching --- gptneox_attn_replace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 13531d13..37a2c405 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -132,6 +132,7 @@ def shift(qkv, num_heads, head_dim): query = shift(query, self.num_attention_heads, self.head_size).contiguous() key = shift(key, self.num_attention_heads, self.head_size).contiguous() value = shift(value, self.num_attention_heads, self.head_size).contiguous() + attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) # Compute attention if use_flash_attn: From 9e30a157f92f9514830a50c42a5627ebea82e6d8 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 19:47:11 +0900 Subject: [PATCH 09/16] [add] torch autocast for flash attention safety + flash attention only supports in fp16/bf16 --- fine-tune.py | 19 ++++++++++--------- supervised-fine-tune.py | 11 ++++++----- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index 185344fd..e05d81dc 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -182,15 +182,16 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing - - trainer = Trainer( - model=model, tokenizer=tokenizer, args=training_args, - train_dataset=dataset["train"], - eval_dataset=None, - data_collator=data_collator) - trainer.train() + with torch.cuda.amp.autocast(dtype=model.dtype): + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, + train_dataset=dataset["train"], + eval_dataset=None, + data_collator=data_collator) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index fefbc2b2..1f19db73 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -291,12 +291,13 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - model.config.use_cache = False # required for gradient checkpointing - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing + with torch.cuda.amp.autocast(dtype=model.dtype): + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing - trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) - trainer.train() + trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) From 3f9c47c9a36edbbbc1b22bbad19a239c3d0a938b Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 19:47:48 +0900 Subject: [PATCH 10/16] [fix] HF built-in rotary embedding is not compatible with flash-attention + cos/sin cache tensor is not trained parameter, so it's not autocast along with other model parameters through `torch_dtype`. --- gptneox_attn_replace.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 37a2c405..1ffea54e 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -4,13 +4,20 @@ import warnings import torch import transformers -from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb from flash_attn import flash_attn_varlen_func group_size_ratio = 1/4 +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1).to(q.dtype), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1).to(k.dtype), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed def _flash_attn(query, key, value, attention_mask=None, head_mask=None): # Flash attention codes from From b21e9491da3b8809ba91112dc0c70cefbc6b9f4c Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 20:06:01 +0900 Subject: [PATCH 11/16] [add] missing local reference for rotate_half --- gptneox_attn_replace.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 1ffea54e..73b44c42 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -10,6 +10,12 @@ group_size_ratio = 1/4 +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) From b22427371e751718c16e4d1597747b563f135830 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Mon, 2 Oct 2023 21:21:36 +0900 Subject: [PATCH 12/16] [rollback] torch.cuda autocast causes half precision error + Works fine without the torch.cuda autocast context, so rollback. --- fine-tune.py | 19 +++++++++---------- supervised-fine-tune.py | 11 +++++------ 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index e05d81dc..9d2e37fc 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -182,16 +182,15 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - with torch.cuda.amp.autocast(dtype=model.dtype): - model.config.use_cache = False # required for gradient checkpointing - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing - trainer = Trainer( - model=model, tokenizer=tokenizer, args=training_args, - train_dataset=dataset["train"], - eval_dataset=None, - data_collator=data_collator) - trainer.train() + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, + train_dataset=dataset["train"], + eval_dataset=None, + data_collator=data_collator) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index 1f19db73..fefbc2b2 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -291,13 +291,12 @@ def train(): # enable trainable params [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] - with torch.cuda.amp.autocast(dtype=model.dtype): - model.config.use_cache = False # required for gradient checkpointing - model.enable_input_require_grads() # required for gradient checkpointing - model.gradient_checkpointing_enable() # enable gradient checkpointing + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing - trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) - trainer.train() + trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) From 9123e42229ccef749b15dbba65d35feac0d013cc Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:35:10 +0900 Subject: [PATCH 13/16] [fix] flash attention causing in-place operation runtime errors --- gptneox_attn_replace.py | 57 +++++++++++++---------------------------- 1 file changed, 18 insertions(+), 39 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 73b44c42..94dc5a9c 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -5,7 +5,9 @@ import torch import transformers -from flash_attn import flash_attn_varlen_func +from einops import rearrange +from flash_attn import flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input group_size_ratio = 1/4 @@ -25,45 +27,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + def _flash_attn(query, key, value, attention_mask=None, head_mask=None): - # Flash attention codes from - # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py - - # q, k, v: [bs, nh, seq_len, hd] - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - value_length = value.size(-2) - - # q, k, v: [bs, nh, seq_len, hd] -> [bs, seq_len, nh, hd] -> [bs * seq_len, nh, hd] - query = query.transpose(1, 2).reshape(batch_size * query_length , num_attention_heads, attn_head_size) - key = key.transpose(1, 2).reshape(batch_size * key_length, num_attention_heads, attn_head_size) - value = value.transpose(1, 2).reshape(batch_size * value_length, num_attention_heads, attn_head_size) - - attn_dropout = 0.0 # TODO: attach to config - - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * query_length, - step=query_length, - dtype=torch.int32, - device=query.device, - ) - - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * key_length, - step=key_length, - dtype=torch.int32, - device=key.device, - ) - - attn_output, attn_weights, _ = flash_attn_varlen_func( - query, key, value, cu_seqlens_q, cu_seqlens_k, query_length, value_length, dropout_p=attn_dropout, - softmax_scale=None, causal=True, return_attn_probs=True - ) - - attn_output = attn_output.view(batch_size, query_length, num_attention_heads, attn_head_size).transpose(1, 2) - return attn_output, attn_weights + # transform the data into the qkv packed form + qkv = torch.stack( + [query, key, value], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + bsz, q_len = qkv.shape[:2] + + qkv = rearrange(qkv, "b s ... -> (b s) ...") + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) + output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + + # disable attn weights by returning None when using flash attention + return output, None def get_forward_function(use_flash_attn=True, use_full=False): From 7203de200447c5e26f08306e0c84c3afa2f0d948 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:36:42 +0900 Subject: [PATCH 14/16] [fix] mixed use of tabs and spaces --- gptneox_attn_replace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index 94dc5a9c..fda6a221 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -41,8 +41,8 @@ def _flash_attn(query, key, value, attention_mask=None, head_mask=None): output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True) output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - # disable attn weights by returning None when using flash attention - return output, None + # disable attn weights by returning None when using flash attention + return output, None def get_forward_function(use_flash_attn=True, use_full=False): From 8a11ef871a2fb3c9dd6b2b611ee26a7f5ed64009 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:45:21 +0900 Subject: [PATCH 15/16] [change] readme back to where it came from the original repo --- README.md | 81 +++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 59512a95..ff2c503b 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,41 @@ -# LongLoRA (with GPTNeoX support): Efficient Fine-tuning of Long-Context Large Language Models +[![Gradio](https://img.shields.io/badge/Gradio-Online%20Demo-blue)](https://2060079530708e861d.gradio.live) + +# LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models + +## News +- [x] [2023.9.22] We release our **13B and 70B 32k models with the supervised fine-tuning**, which is feasible for long context QA. Please check [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). To our best knowledge, **this is the first work that release 70B model with 32k context length**. +- [x] [2023.9.22] We release all our fine-tuned [models](https://huggingface.co/Yukang), including **70B-32k models**, [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k), [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft). Welcome to check them out! +- [x] [2023.9.22] We release [Paper](http://arxiv.org/abs/2309.12307) and this GitHub repo, including training and evaluation code. + +**LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models [[Paper](http://arxiv.org/abs/2309.12307)]**
+[Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl=en), +[Shengju Qian](https://scholar.google.com/citations?user=QNnWmasAAAAJ), +[Haotian Tang](https://scholar.google.com/citations?user=WxL13BAAAAAJ&hl), +[Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl=zh-CN), +[Zhijian Liu](https://scholar.google.com/citations?user=3coYSTUAAAAJ&hl=en), +[Song Han](https://scholar.google.com/citations?user=E0iCaa4AAAAJ&hl=zh-CN), +[Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en)
+ +
**Paper** | **Models** | [**Training**](#training) | [**Inference**](#inference) | **Online Demo**
+ +

+

+

+

+

+

+ +## Abstract +We present LongLoRA, an efficient fine-tuning approach that extends the context sizes of pre-trained large language models (LLMs), with limited computation cost. +Typically, training LLMs with long context sizes is computationally expensive, requiring extensive training hours and GPU resources. +In this paper, we speed up the context extension of LLMs in two aspects. On the one hand, although dense global attention is needed during inference, fine-tuning the model can be effectively and efficiently done by sparse local attention. The proposed shift short attention effectively enables context extension, leading to non-trivial computation saving with similar performance to fine-tuning with vanilla attention. On the other hand, we find that LoRA for context extension works well under the premise of trainable embedding and normalization. LongLoRA demonstrates strong empirical results on various tasks on LLaMA2 models from 7B/13B to 70B. LongLoRA adopts LLaMA2 7B from 4k context to 100k, or LLaMA2 70B to 32k on a single 8x A100 machine. LongLoRA extends models' context while retaining their original architectures, and is compatible with most existing techniques, like FlashAttention-2. In addition, to make LongLoRA practical, we collect a dataset, LongQA, for supervised fine-tuning. It contains more than 3k long context question-answer pairs. For more details, please refer to the [paper](http://arxiv.org/abs/2309.12307). -This repo provides on top of the original implementation, support for GPTNeoX with Flash-Attention and the LongLoRA's shifted short attention as needed. ## Highlights **LongLoRA** speed up the context extension of pre-trained large language models in both attention-level and weight-level. 1. The proposed shifted short attention is easy to implement, compatible with Flash-Attention, and not required during inference. - +2. We release all our models, including models from 7B to 70B, context length from 8k to 100k, including [LLaMA2-LongLoRA-7B-100k](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft), [LLaMA2-LongLoRA-13B-64k](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k), and [LLaMA2-LongLoRA-70B-32k](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k). +3. We build up a long-context QA dataset, LongQA, for supervised fine-tuning (SFT). We release 13B and 70B 32k models with SFT, [Llama-2-13b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) and [Llama-2-70b-chat-longlora-32k-sft](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft). We will further release the dataset in the next month. ## Installation ``` @@ -13,16 +43,46 @@ pip install -r requirements.txt pip install flash-attn --no-build-isolation ``` +## Released models + +### Models with supervised fine-tuning +| Model | Size | Context | Train | Link | +|:----------------------------------|------|---------|---------|-------------------------------------------------------------------------| +| Llama-2-13b-chat-longlora-32k-sft | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-chat-longlora-32k-sft) | +| Llama-2-70b-chat-longlora-32k-sft | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k-sft) | + +### Models with context extension via fully fine-tuning +| Model | Size | Context | Train | Link | +|:----------------------------|------|---------|-------|-------------------------------------------------------------------| +| Llama-2-7b-longlora-8k-ft | 7B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k-ft) | +| Llama-2-7b-longlora-16k-ft | 7B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k-ft) | +| Llama-2-7b-longlora-32k-ft | 7B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k-ft) | +| Llama-2-7b-longlora-100k-ft | 7B | 100000 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-100k-ft) | +| Llama-2-13b-longlora-8k-ft | 13B | 8192 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k-ft) | +| Llama-2-13b-longlora-16k-ft | 13B | 16384 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k-ft) | +| Llama-2-13b-longlora-32k-ft | 13B | 32768 | Full FT | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k-ft) | + +### Models with context extension via improved LoRA fine-tuning +| Model | Size | Context | Train | Link | +|:----------------------------|------|---------|-------|-------------------------------------------------------------------| +| Llama-2-7b-longlora-8k | 7B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-8k) | +| Llama-2-7b-longlora-16k | 7B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-16k) | +| Llama-2-7b-longlora-32k | 7B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-7b-longlora-32k) | +| Llama-2-13b-longlora-8k | 13B | 8192 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-8k) | +| Llama-2-13b-longlora-16k | 13B | 16384 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-16k) | +| Llama-2-13b-longlora-32k | 13B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-32k) | +| Llama-2-13b-longlora-64k | 13B | 65536 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-13b-longlora-64k) | +| Llama-2-70b-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-longlora-32k) | +| Llama-2-70b-chat-longlora-32k | 70B | 32768 | LoRA+ | [link](https://huggingface.co/Yukang/Llama-2-70b-chat-longlora-32k) | + ## Training ### Pre-trained weights -I used GPTNeoX model as the base model architecture, which was ported from the authors' original repo where Llama2 was used. -Some candidate pre-trained weights may include [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b), [Polyglot-ko-12.8B](https://huggingface.co/EleutherAI/polyglot-ko-12.8b) and other variants. - +We use LLaMA2 models as the pre-trained weights and fine-tune them to long context window sizes. Please download [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), [Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b-hf), and [Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf), based on your choices. ### Fine-tuning ``` torchrun --nproc_per_node=8 fine-tune.py \ - --model_name_or_path path_to/gpt_neox_model_hf \ + --model_name_or_path path_to/Llama-2-7b-hf \ --bf16 True \ --output_dir path_to_saving_checkpoints \ --cache_dir path_to_cache \ @@ -47,7 +107,7 @@ torchrun --nproc_per_node=8 fine-tune.py \ --max_steps 1000 ``` -- Please remember to change `path_to/gpt_neox_model_hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. +- Please remember to change `path_to/Llama-2-7b-hf`, `path_to_saving_checkpoints`, `path_to_cache` to your own directory. - Note that you can change `model_max_length` to other values. - You could change `ds_configs/stage2.json` to `ds_configs/stage3.json` if you want. - Please set `use_flash_attn` as `False` if you use V100 machines or do not install flash attention. @@ -83,7 +143,7 @@ torchrun --nproc_per_node=8 supervised-fine-tune.py \ --deepspeed "ds_configs/stage2.json" \ --tf32 True ``` -- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models` +- We typically make supervised fine-tuning upon the fine-tuned context extended models, `path_to_finetuned_models`, like `Llama-2-13b-longlora-32k` or `Llama-2-13b-longlora-32k-ft`. - During our dataset collection, it is hard for us to collect many high-quality QA that are larger than 32768. Thus, if you use our `LongQA.json`, please also set `model_max_length` as 32768. @@ -222,8 +282,7 @@ If you find this project useful in your research, please consider citing: ``` ## Acknowledgement -- This work is an GPTNeoX port of the work from the original authors' code. [LongLoRA](https://github.com/dvlab-research/LongLoRA) -- This work is built upon the [GPTNeoX-HF](https://huggingface.co/docs/transformers/model_doc/gpt_neox) which is based upon [EleutherAI/GPTNeoX](https://github.com/EleutherAI/gpt-neox) as the pre-trained model architecture. +- This work is built upon the [LLaMA2](https://ai.meta.com/llama) as the pre-trained models. - This work is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed), [peft](https://github.com/huggingface/peft), and [Flash-Attention2](https://github.com/Dao-AILab/flash-attention) for acceleration. - Some evaluation code is modified upon [Landmark Attention](https://github.com/epfml/landmark-attention). - We use [LongChat](https://github.com/DachengLi1/LongChat) for the retrieval evaluation. From 02e4c1cd3fa748c3a630ffcde3bcd284b9071fd3 Mon Sep 17 00:00:00 2001 From: naubull2 Date: Tue, 3 Oct 2023 17:47:40 +0900 Subject: [PATCH 16/16] [remove] unused comments --- gptneox_attn_replace.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py index fda6a221..ee16bb0b 100644 --- a/gptneox_attn_replace.py +++ b/gptneox_attn_replace.py @@ -112,9 +112,7 @@ def shift(qkv, num_heads, head_dim): qkv = qkv.transpose(1, 2) # qkv = [bsz, q_len, nh, d] qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) - #qkv = qkv.transpose(1, 2) - # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well # -> [bsz * n_group, group_s, nh, d) # -> [bsz * n_group, nh, group_s, d) qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) @@ -124,6 +122,7 @@ def shift(qkv, num_heads, head_dim): query = shift(query, self.num_attention_heads, self.head_size).contiguous() key = shift(key, self.num_attention_heads, self.head_size).contiguous() value = shift(value, self.num_attention_heads, self.head_size).contiguous() + attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) # Compute attention @@ -136,7 +135,6 @@ def shift(qkv, num_heads, head_dim): if self.training and not use_full: attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size) - #attn_output = attn_output.transpose(1, 2) # [bsz, q_len, nh, hd] attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) attn_output = attn_output.transpose(1, 2)