QLoRA 微调完整流程:环境搭建到模型合并
FreeGuideOnline
最新
2026-06-22
bash pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 pip install bitsandbytes==0.41.0 # 务必版本 ≥0.41.0 pip install transformers peft accelerate datasets pip install trl # 如果你用 SFTTrainer pip install sentencepiece # 如使用 Llama 等 tokenizer
**验证安装**:在 Python 中执行 `import bitsandbytes as bnb; print(bnb.__version__)` 无报错即可。
### 3. 加载量化后的基础模型
用 `BitsAndBytesConfig` 配置 4 位量化,并兼容 NF4 数据类型(QLoRA 论文推荐)。注意:模型需支持 HF 的 `AutoModelForCausalLM`,例如 Llama‑2、Mistral、Falcon 等。
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_name = "meta-llama/Llama-2-7b-hf" # 你可换为其他模型
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True, # 双重量化,进一步节省显存
bnb_4bit_quant_type="nf4", # NormalFloat 4 位
bnb_4bit_compute_dtype=torch.bfloat16 # 计算时用 bfloat16
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto", # 自动分配权重到 GPU/CPU
trust_remote_code=True # 部分模型需要
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token # 设置 pad_token,避免警告
4. 准备微调数据集
QLoRA 通常用于指令微调,数据集需整理成 {“instruction”: …, “input”: …, “output”: …} 的格式,或直接文本行。这里以 Alpaca 格式为例,用 datasets 加载并处理后调用 tokenizer。
from datasets import load_dataset
dataset = load_dataset("yahma/alpaca-cleaned", split="train") # 示例数据集
def format_prompt(sample):
if sample["input"]:
prompt = f"""Below is an instruction that describes a task, paired with an input.
### Instruction:
{sample["instruction"]}
### Input:
{sample["input"]}
### Response:
{sample["output"]}"""
else:
prompt = f"""Below is an instruction that describes a task.
### Instruction:
{sample["instruction"]}
### Response:
{sample["output"]}"""
return prompt
def tokenize_function(examples):
full_texts = [format_prompt(ex) for ex in examples]
tokenized = tokenizer(
full_texts, truncation=True, max_length=512, padding="max_length",
return_tensors="pt"
)
tokenized["labels"] = tokenized["input_ids"].clone()
return tokenized
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
5. 配置 LoRA 适配器
依托 peft 库,设置 LoRA 的秩(r)、alpha、目标模块等。对所有线性层(q_proj, v_proj 等)添加适配器能获得最佳效果。
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# 让模型适配 4bit 训练,添加梯度检查点等
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
r=16, # 秩
lora_alpha=32, # 缩放因子
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters() # 查看可训练参数量,通常 < 原模型的 1%
6. 开始训练
推荐使用 transformers.Trainer 或 trl.SFTTrainer。以下使用 SFTTrainer,它简化了数据加载与格式化。
from trl import SFTTrainer
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="./qlora-llama2-7b",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
fp16=True, # 显卡若支持 Ampere 架构可用 bf16
logging_steps=10,
save_strategy="epoch",
dataloader_num_workers=2,
optim="paged_adamw_8bit", # 节省显存的优化器
report_to="none" # 若需用 wandb 可修改
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
tokenizer=tokenizer,
dataset_text_field="input_ids", # 注意 SFTTrainer 直接读取已 tokenized 的字段
max_seq_length=512,
packing=False, # 是否将多个短样本拼接,减少 padding
)
trainer.train()
7. 保存适配器与合并模型
训练完成后,先保存轻量级的 LoRA 适配器(通常只有几十 MB):
model.save_pretrained("./qlora-llama2-7b-adapter")
tokenizer.save_pretrained("./qlora-llama2-7b-adapter")
将适配器合并回基础模型,并保存为完整权重(以便于部署或推理):
import torch
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.bfloat16, # 合并后可全精度存放
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 加载 LoRA 适配器
model = PeftModel.from_pretrained(base_model, "./qlora-llama2-7b-adapter")
# 合并并卸载 LoRA 层,返回普通模型结构
merged_model = model.merge_and_unload()
# 保存合并后的完整模型
merged_model.save_pretrained("./qlora-llama2-7b-merged")
tokenizer.save_pretrained("./qlora-llama2-7b-merged")