Skip to content

PEFT TigerBot 7b with QLoRA, building an domain LLM on one consumer level GPU in hours

i4never edited this page Jun 28, 2023 · 7 revisions

6月7日,我们开源了TigerBot,同时发布了7b-base7b-sft180b-research版本的模型、数据、以及pretrain/sft阶段的训练代码。但是,受限于参数量与计算资源,即使借助deepspeed等框架,微调模型的硬件门槛仍然很高,并且容易带来预训练知识遗忘等问题。

为此,我们开源基于QLoRA,以tigerbot-7b-sft模型为基座,在medical-qa数据上微调的代码。在max_length=1024batch_size=1的配置下,微调可以在1张RTX 3090(24G)上运行。

QLoRA

由于很多下游任务的本征维度并不大(Exploring Universal Intrinsic Task Subspace via Prompt Tuning Chapter 2) ,因此微调小部份参数,模型很可能在下游任务就能获得不错的结果。LoRA借鉴这一想法,选择对模型的所有线性层(attention中的qkv映射、ffn层)做增量微调,并且对增量做了低秩分解的假设。假设原模型的各线性层在激活前的计算为:

$$Y=W_{d*d}X+b$$

lora为该层添加增量权重:

$$Y=W_{d*d}X+b+A_{d*r}B_{r*d}X$$

训练时只微调其中的 $AB$ 矩阵。其中 $r$ 称作 $lora\ rank$ (通常取4、8)。一般来说, $d$ >> $r$ ,相比于整个模型,LoRA的trainable参数非常少。推理时,权重可以被合成为 $Y=(W_{d*d}+A_{d*r}B_{r*d})X+b$ ,相比于原模型没有额外开销。

在LoRA的基础上,QLoRA在训练阶段引入了量化机制,进一步减少了内存需求。

Train

Data

训练数据基于https:/Toyhom/Chinese-medical-dialogue-data。部份数据ask字段中提问较多,answer中没有完全覆盖用户的提问,因此从各个科室的数据中分别采样了1k条ask字段为空的数据作为medical-qa的数据,共6k条,其中4800条用作train,1200条用做eval。

import pandas as pd
import os

paths = ['Andriatria_男科/男科5-13000.csv', 'IM_内科/内科5000-33000.csv', 'OAGD_妇产科/妇产科6-28000.csv', 'Oncology_肿瘤科/肿瘤科5-10000.csv', 'Pediatric_儿科/儿科5-14000.csv', 'Surgical_外科/外科5-14000.csv']

dfs = [pd.read_csv(p, encoding='GB18030') for p in paths]

medical_qa = list()
for df in dfs:
    for _, r in df[df.ask == '无'].sample(n=1000).iterrows():
        medical_qa.append({'instruction': "\n\n### Instruction:\n" + r.title, 'input': None, 'output': "\n\n### Response:\n" + r.answer})
{"instruction": "\n\n### Instruction:\n成年人一旦感染乙肝病毒后是不是都进入急性", "input": null, "output": "\n\n### Response:\n\n\n### Response:\n感染乙肝后,一般情况下并非都是会出现急性经过。相对而言根据发病特点,也可能会存在病情迁延,转化成慢性经过。"}
{"instruction": "\n\n### Instruction:\n风热感冒致鼻塞吃什么药", "input": null, "output": "\n\n### Response:\n你描述的情况可能是风热感冒引起的,可以用双黄连口服液治疗。饮食宜清淡,不能吃辛辣刺激之品,多喝水。可以适当锻炼。"}
{"instruction": "\n\n### Instruction:\n每天晚上左心脏痛,白天就不痛了。", "input": null, "output": "\n\n### Response:\n考虑是有冠心病的情况,可以做一下心电图看看,必要时做一下冠脉造影看看,治疗冠心病可以服用消心痛、立普妥、阿司匹林等,或者舌下含服硝酸甘油治疗。平时一定要注意休息,避免劳累和情绪激动。祝早日康复"}

Train command

完整代码在https:/TigerResearch/TigerBot/train/train_with_qlora.py

得益于Peft库的工程实现,在已有的任意模型上叠加LoRA非常简单。Peft中自动寻找模型线性层,添加adapter,以及合并参数的实现非常漂亮,调用极其简单。https:/artidoro/qlora 额外实现了训练的量化操作,进一步降低了内存需求(需要升级transformers版本)。

训练前需要安装train/requirements_qlora.txt中的依赖(pip install -r requirements_qlora.txt)。

  • 使用QLoRA微调

    CUDA_VISIBLE_DEVICES=0 python ./train_with_qlora.py --model_name_or_path TigerResearch/tigerbot-7b-sft --output_dir ./sft-qlora --data_files ./data/medical_qa_6000.jsonl --do_train --do_eval --num_train_epochs 1 --learning_rate 1e-4 --evaluation_strategy steps --eval_steps 1000 --optim paged_adamw_32bit --bf16 True --lora_r 8 --bits 4 --logging_steps 10 --pad_to_max_length True --per_device_train_batch_size 1 --per_device_eval_batch_size 2 --gradient_accumulation_steps 4
    • 3090*1
  • 全量微调

    deepspeed --include="localhost:0,1,2,3" ./train_with_qlora.py --deepspeed ./ds_config/ds_config_qlora.json --full_finetune True --model_name_or_path TigerResearch/tigerbot-7b-sft --output_dir ./sft-full-finetune --data_files ./medical_qa_6000.jsonl --do_train --do_eval --num_train_epochs 1 --learning_rate 1e-5 --evaluation_strategy steps --eval_steps 1000 --bf16 True --bits 16 --logging_steps 10 --pad_to_max_length True --per_device_train_batch_size 1 --per_device_eval_batch_size 2
    • 3090*4
    • 全量微调开启了zero3、optimizer offload、parameter offload,训练阶段需要额外50G内存
  • 全量微调也也可以使用开源的train_sft,脚本为了公平比较性能,同时减少训练中途长数据导致oom的情况,每条数据都被padding到了max_length(默认为1024)

Merge adapter weights

训练完成后,可以将adapter weights合并到原参数中:

import transformers
from peft import PeftModel

model = transformers.AutoModelForCausalLM.from_pretrained("TigerResearch/tigerbot-7b-sft")
model = PeftModel.from_pretrained(model, './adapter_model', is_trainable=False)
model = model.merge_and_unload()

以原模型推理速度(token/s)为基准,不合并参数会带来 33.234% 的推理速度损失,merge后的推理速度与原模型持平(100.077%)。

Result

peft后的medical-bot模型可以从huggingface上下载。如果你有tigerbot-7b-sft的模型,也可以只下载qlora_ckpt_2400_adapter_model下的adapter参数,与原模型合并后使用。

Metrics

全量微调的模型在2k步时出现了过拟合的现象,终止了后续训练;qlora-sft完整地跑完了2个epoch,没有出现过拟合现象。qlora-sft设置了gradient_accumulation_steps=4,模拟了batch_size=4的情况。

Train loss Eval loss Metric

Model output

模型均以torch.float32加载,生成参数均为:

generation_kwargs = {
  "top_p": 0.95,
  "temperature": 0.8,
  "max_length": 256,
  "early_stopping": True,
}

tigerbot-7b-sft:未经medical-qa数据微调的模型

tigerbot-7b-qlora-sft:tigerbot-7b-sft经过QLoRA微调2.4k step后的模型

Tigerbot-7b-full-sf:tigerbot-7b-sft经过全参数微调2k step后的模型

  • 发高烧怎么办?

    tigerbot-7b-sft (44.196 token/s):
    发高烧是身体在提醒你,你的身体需要更多的水,更多的休息,更多的营养。
    不要怕,不要慌,不要急,不要乱,冷静的对待,慢慢的解决。
    多喝水,多休息,多补充营养
    
    tigerbot-7b-qlora-sft (no merge: 34.748 token/s merge: 47.688 token/s):
    你的这种情况可以先用酒精擦浴退热,然后再查明发热的病因,引起发热的原因很多,最常见的是感染,其次是结缔组织病、恶性肿瘤等。先不急着用药。
    
    tigerbot-7b-full-sft (47.876 token/s):
    你的这种情况可以先用酒精擦浴退热,然后再查明发热的病因,引起发热的原因很多,最常见的是感染,其次是结缔组织病、恶性肿瘤等。先不急着用药。
    
  • 7个月宝宝咳嗽怎么办?

    tigerbot-7b-sft (47.193 token/s):
    病情分析:
    你好,宝宝咳嗽有痰鸣的话,喝盐酸氨溴索或者给宝宝做雾化,平时多给宝宝喝水.
    指导意见:
    宝宝咳嗽有痰鸣的话,喝盐酸氨溴索或者给宝宝做雾化,平时多给宝宝喝水.
    
    tigerbot-7b-qlora-sft (no merge: 34.989 token/s merge: 47.568 token/s):
    考虑是支气管炎引起的,可以服用小儿清肺化痰颗粒,配合阿奇霉素颗粒治疗,注意保暖预防发生感冒的情况,定期做好复查。
    
    tigerbot-7b-full-sft (48.327 token/s):
    考虑是支气管炎引起的,可以给孩子喂服麻甘颗粒,清竹沥口服液试试,如果血象高的话需要配合抗生素治疗的,阿奇霉素分散片效果不错的,让宝宝多喝水,注意室内经常通风换气,慢慢会改善的,祝宝宝健康成长
    
  • 如何缓解化疗后的不良反应?

    tigerbot-7b-sft (46.746 token/s):
    化疗后,患者常常会出现恶心、呕吐、脱发、食欲不振、疲乏无力、睡眠不佳、口干、口腔溃疡、皮肤瘙痒、便秘、腹泻等一系列症状,这些症状统称为化疗后的不良反应。  化疗后不良反应的处理:  1、恶心、呕吐:  化疗后恶心、呕吐是化疗最常见的副作用之一,一般出现在化疗后前三天,多数患者在化疗后一、二周内可自行缓解。  处理:  (1)对症治疗:恶心、呕吐严重者,可给予吗丁啉或胃复安等胃动力药;呕吐频繁者可给予维生素B6或维生素B1治疗。  (2)对病因治疗:如化疗药物引起的恶心、呕吐,可换用其他药物或化疗间歇期,或减量、缩短化疗周期。  (3)对症治疗:恶心、呕吐严重者,可给予吗丁啉或胃复安等胃动力药;呕吐频繁者可给予维生素B6或维生素B1治疗
    
    tigerbot-7b-qlora-sft (no merge: 34.739 token/s merge: 47.927 token/s):
    化疗后出现的恶心、呕吐、食欲不振、脱发、白细胞降低、免疫力下降等不良反应,可以服用中药进行调理,如参苓白术散、补中益气丸等,也可以服用维生素B6、谷维素等药物进行治疗
    
    tigerbot-7b-full-sft (47.110 token/s):
    根据你的情况属于药物反应有关,需要增加营养补充维生素微量元素,易消化易吸收饮食,避免不良刺激,可以服用参苓白术散调理
    
  • 低压高是啥原因? (tigerbot-7b-sft模型没有足够上下文与领域知识,“回显”了问题)

    tigerbot-7b-sft (47.007 token/s):
    低压高是啥原因?
    
    tigerbot-7b-qlora-sft (no merge: 34.621 token/s merge: 47.402 token/s):
    高血压是心脑血管疾病的主要危险因素,高血压患者要定期复查,根据病情调整用药,平时注意低盐低脂饮食,适当运动,保持心情舒畅。
    
    tigerbot-7b-full-sft (47.665 token/s):
    低压高是肾上腺皮质增生症,是良性肿瘤,多发生于青年男性,临床表现为高血压、肾上腺皮质肿大、压迫症状等
    
  • 法国首都是哪里?

    tigerbot-7b-sft (38.726 token/s):
    法国的首都是巴黎。
    
    tigerbot-7b-qlora-sft (no merge: 33.816 token/s merge: 46.806 token/s):
    巴黎。法国首都巴黎,位于法国北部,塞纳河西岸,距河口(英吉利海峡)375公里
    
    tigerbot-7b-full-sft (40.255 token/s): 
    法国的首都是巴黎。
    
  • 中国古代四大美女是谁?

    tigerbot-7b-sft (45.188 token/s):
    中国古代四大美人是:西施、王昭君、貂蝉、杨玉环。
    
    tigerbot-7b-qlora-sft (no merge: 33.512 token/s merge: 44.782 token/s):
    西施、王昭君、貂蝉、杨玉环。
    
    tigerbot-7b-full-sft (44.909 token/s):
    西施,王昭君,貂蝉,杨玉环。
    
  • what are the most famous artwork in Louvre

    tigerbot-7b-sft (34.680 token/s):
    The Mona Lisa
    
    tigerbot-7b-qlora-sft (no merge: 28.519 token/s merge: 46.655 token/s):
    The Mona Lisa (La Gioconda), painted by Leonardo da Vinci in the early 16th century, is the most famous artwork in the Louvre.
    
    tigerbot-7b-full-sft (47.602 token/s):
    The Mona Lisa is the most famous artwork in the Louvre. It is a portrait of Lisa Gherardini, wife of a Florentine merchant, painted by Leonardo da Vinci in the early 16th century.
    
  • tell me about bitcoin

    tigerbot-7b-sft (46.367 token/s):
    Bitcoin is a digital currency or cryptocurrency created in 2009. It is the first and most widely used cryptocurrency. Bitcoin is decentralized, meaning it is not controlled by any government or central bank. It is also transparent, as all transactions are recorded in a public ledger known as the blockchain. Bitcoin is a form of digital money that can be exchanged online and used to make purchases. It is not a physical currency, but rather a digital asset that is stored in a digital wallet. Bitcoin is not backed by any physical assets, but rather by a network of computers that maintain a public ledger of all bitcoin transactions. Bitcoin is a form of digital money that can be exchanged online and used to make purchases.
    
    tigerbot-7b-qlora-sft (no merge: 33.795 token/s merge: 48.214 token/s):
    Bitcoin is a digital currency or cryptocurrency that was launched in 2009. It is a decentralized, global, and public ledger that records all bitcoin transactions. Bitcoin is not controlled by any central bank or government and is not backed by any physical assets. It is instead powered by a complex network of computers that maintain the ledger and verify transactions. Bitcoin is used to make purchases online and can be exchanged for other currencies.
    
    tigerbot-7b-full-sft (48.401 token/s):
    Bitcoin is a cryptocurrency that was launched in 2009. It is decentralized, meaning it is not controlled by any government or central bank. Bitcoin is powered by a public ledger called the blockchain, which is maintained by a network of computers. Transactions made on the blockchain are irreversible, meaning once a transaction is made, it cannot be reversed. Bitcoin is also highly secure, as it uses cryptographic algorithms to verify and secure transactions.
    

reference

https://huggingface.co/TigerResearch

https:/TigerResearch/TigerBot/train/train_with_qlora.py

Exploring Universal Intrinsic Task Subspace via Prompt Tuning

LoRA: Low-Rank Adaptation of Large Language Models

QLoRA: Efficient Finetuning of Quantized LLMs