QLoRA: Efficient LLM Fine-Tuning
Parameter-efficient fine-tuning (PEFT) techniques offer a way to fine-tune large language models (LLMs) on custom datasets with minimal computational resources. These techniques update only a small number of parameters during fine-tuning and freeze the rest.
LoRA (Low-Rank Adaptation) is a PEFT technique that freezes the pre-trained weights and trains two small matrices whose product approximates the weight update. Its quantized variant, QLoRA, additionally loads the frozen base model in 4-bit precision while keeping the adapter matrices in higher precision, reducing memory use even further.
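Concretely, for a frozen weight matrix W of shape d × k, LoRA learns an update ΔW = B·A, where B is d × r and A is r × k with a rank r much smaller than d and k, so only r·(d + k) parameters are trained instead of d·k.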
Here’s a broad overview of how to use QLoRA to fine-tune an LLM.
1. Install Required Libraries
- Key Libraries:
- bitsandbytes for efficient model loading and quantization.
- transformers for pre-trained models and utilities.
- peft for parameter-efficient fine-tuning.
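For reference, a typical install command (versions unpinned; accelerate and datasets are not named above but are commonly needed for device placement and data loading):
pip install bitsandbytes transformers peft accelerate datasets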
2. Prepare the Model for QLoRA
- Configure the model to use 4-bit quantization using BitsAndBytesConfig.
- Load a pre-trained model in 4-bit format using the transformers library and the 4-bit config.
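Code Snippet for 4-Bit Model Loading (a minimal sketch: microsoft/phi-2 is a stand-in, any causal LM from the Hugging Face Hub works, and NF4 with bfloat16 compute is one common configuration):
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit quantization: NF4 weights, bfloat16 compute, and double
# quantization to shave off a little more memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True)

# Stand-in model name; replace with the checkpoint you want to tune
model_name = 'microsoft/phi-2'
pretrained_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map='auto')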
3. Tokenization and Pre-processing
- Configure the tokenizer for memory-friendly batching (e.g., set a padding token and cap the sequence length).
- Pre-process the dataset to match the model’s expected input format (completion, summarization, sentiment analysis, etc.).
- Tokenize the pre-processed dataset, then batch and shuffle it, as sketched below.
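Code Snippet for Tokenization (a minimal sketch: the IMDB dataset stands in for your own data and is assumed to have a 'text' column, and model_name comes from the loading step above):
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Many causal LMs ship without a pad token; reuse EOS so padding works
tokenizer.pad_token = tokenizer.eos_token

# Stand-in dataset; swap in your own
dataset = load_dataset('imdb', split='train[:1%]')

def tokenize(batch):
    # Truncate to a fixed length to bound memory use
    return tokenizer(batch['text'], truncation=True, max_length=512)

tokenized = dataset.map(tokenize, batched=True,
                        remove_columns=dataset.column_names)
# Shuffle, then hold out a small evaluation split
split = tokenized.shuffle(seed=42).train_test_split(test_size=0.1)
train_dataset, eval_dataset = split['train'], split['test']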
4. Fine-Tuning with QLoRA
- Prepare the model for QLoRA using the prepare_model_for_kbit_training() function from the peft library.
- Define a LoraConfig specifying the rank, target modules, and other hyperparameters, and wrap the model with it.
- Use transformers.Trainer with custom training arguments to fine-tune the adapter.
Code Snippet for Fine-Tuning Setup:
from transformers import (TrainingArguments, Trainer,
                          DataCollatorForLanguageModeling)
from peft import (LoraConfig, get_peft_model,
                  prepare_model_for_kbit_training)

# Prepare the 4-bit model for training (casts layer norms and
# enables gradient checkpointing to save memory)
model = prepare_model_for_kbit_training(pretrained_model)

# LoRA config: rank r, scaling factor lora_alpha, and the modules
# to adapt (target module names vary by model architecture)
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'dense'],
    task_type='CAUSAL_LM')

# Wrap the base model with trainable LoRA adapters
peft_model = get_peft_model(model, lora_config)

# Training args
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs')

# Collator that derives causal-LM labels from the input ids,
# so the Trainer can compute a loss
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Init Trainer on the adapter-wrapped model
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator)

# Fine-tune: only the LoRA adapter weights are updated
trainer.train()
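After training, only the small adapter needs to be saved rather than the full model. A minimal sketch (the directory name is arbitrary):
# Save the LoRA adapter weights and the tokenizer
peft_model.save_pretrained('./qlora-adapter')
tokenizer.save_pretrained('./qlora-adapter')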