Fine-Tuning Llama 3.2 11B for Extractive Question Answering: A Comprehensive Guide
Large Language Models (LLMs) are powerful tools that can perform a wide range of natural language processing tasks. However, because their training is generic and broadly focused, they do not always perform optimally on a specific task out of the box. Fine-tuning adapts a pre-trained LLM to such a task, and parameter-efficient methods like LoRA do so without altering the original weights. In this article, we will fine-tune Llama 3.2 11B for extractive question answering using the QLoRA technique (LoRA adapters on top of a quantized base model) and demonstrate the performance boost on the SQuAD v2 dataset.
What is LoRA?
LoRA (Low-Rank Adaptation) is a technique for changing a model's behavior without touching its original weights. It adds small "adapter" weight matrices to selected layers; only these adapters are updated during training while the original weights stay frozen. Because the pre-trained weights are untouched, the model retains its general knowledge while the adapters encode the new, task-specific behavior.
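To make the idea concrete, here is a toy sketch of the LoRA forward pass in plain PyTorch. It is illustrative only, not how the peft library implements it, and the dimensions are arbitrary:
import torch

# Frozen pre-trained weight of one layer (dimensions made up for illustration).
d, k, r, alpha = 4096, 4096, 128, 256
W = torch.randn(d, k)
# LoRA adapters: A starts random, B starts at zero so the initial update is zero.
A = torch.randn(r, k) * 0.01
B = torch.zeros(d, r)
x = torch.randn(k)
# Output = original projection + scaled low-rank correction; only A and B are trained.
h = W @ x + (alpha / r) * (B @ (A @ x))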
Defining the Experiment
In this experiment, we will fine-tune Llama 3.2 11B for extractive question answering on the SQuAD v2 dataset. The goal is to train the model to extract the exact span of text that answers a user's question, without summarizing or rephrasing.
System Environment
This experiment was run on Google Colab with an A100 GPU. The code is written in Python and relies on the Hugging Face ecosystem: Transformers, PEFT, TRL, bitsandbytes, Datasets, and Evaluate.
Installing Packages
!pip install -U transformers peft bitsandbytes datasets trl evaluate accelerate bert_score
Loading Data
We will use the SQuAD v2 dataset for training and evaluation.
from datasets import load_dataset
ds = load_dataset("squad_v2")
print(ds)
Output:
DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})
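Each record pairs a context paragraph with a question and its ground-truth answer span(s); unanswerable questions, which SQuAD v2 adds on top of SQuAD v1, have an empty answers list. You can peek at a record to see the fields we will turn into a conversation:
print(ds["train"][0]["question"])
print(ds["train"][0]["answers"])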
Data Preparation
We will split the dataset into training, validation, and test sets and convert the samples into a format suitable for Llama.
num_training_samples = 15000
num_test_samples = 750
num_validation_samples = 1000
training_samples = ds['train'].select(range(num_training_samples))
test_samples = ds['train'].select(range(num_training_samples, num_training_samples + num_test_samples))
validation_samples = ds['validation'].select(range(num_validation_samples))
def convert_squad_sample_to_llama_conversation(sample):
    # ...
    return {"text": sample_conversation, "messages": messages, "answer": answer}
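The body of the conversion helper is omitted above; a minimal sketch of what it might look like follows. The system prompt wording and the no-answer fallback string are assumptions rather than the author's originals, and the sketch uses the tokenizer that is loaded in the Model Preparation section below.
def convert_squad_sample_to_llama_conversation(sample):
    question = sample["question"]
    context = sample["context"]
    # SQuAD v2 marks unanswerable questions with an empty answer list.
    has_answer = len(sample["answers"]["text"]) > 0
    answer = sample["answers"]["text"][0] if has_answer else "The context does not contain the answer."
    messages = [
        {"role": "system", "content": "Answer by copying the exact span of the context that answers the question. Do not rephrase or summarize."},
        {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
        {"role": "assistant", "content": answer},
    ]
    # Render the conversation with the model's chat template into a single training string.
    sample_conversation = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": sample_conversation, "messages": messages, "answer": answer}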
conversation_training_samples = training_samples.map(convert_squad_sample_to_llama_conversation)
conversation_test_samples = test_samples.map(convert_squad_sample_to_llama_conversation)
conversation_validation_samples = validation_samples.map(convert_squad_sample_to_llama_conversation)
Model Preparation
We will load the Llama 3.2 11B model with 4-bit quantization (the "Q" in QLoRA) and set up the LoRA configuration.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"  # decoder-only models should be left-padded for batched generation
tokenizer.pad_token = tokenizer.eos_token  # Llama has no dedicated pad token, so reuse EOS
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False
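As an optional sanity check, you can confirm how much GPU memory the 4-bit weights actually occupy:
print(f"Model footprint: {model.get_memory_footprint() / 1e9:.2f} GB")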
from peft import LoraConfig
rank = 128  # LoRA rank: the dimensionality of the low-rank adapter matrices
alpha = rank*2  # common heuristic; gives an effective adapter scaling factor (alpha / r) of 2
peft_config = LoraConfig(
    r=rank,
    lora_alpha=alpha,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj']
)
Training
We will use the SFTTrainer from the trl library to train the model.
from transformers import TrainingArguments
from trl import SFTTrainer
model_checkpoint_path = "llama-3.2-11b-squadv2-checkpoints"  # example output directory
training_arguments = TrainingArguments(
    output_dir=model_checkpoint_path,
    optim='paged_adamw_32bit',
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # effective batch size of 32
    log_level='debug',
    evaluation_strategy="steps",
    save_strategy='steps',
    logging_steps=8,
    eval_steps=8,
    save_steps=8,
    learning_rate=1e-4,
    fp16=True,
    num_train_epochs=4,
    max_steps=120,  # max_steps takes precedence over num_train_epochs, so training stops after 120 optimizer steps
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    overwrite_output_dir=True,
    lr_scheduler_type='linear',
)
trainer = SFTTrainer(
    model=model,
    train_dataset=conversation_training_samples,
    eval_dataset=conversation_test_samples,
    peft_config=peft_config,
    dataset_text_field='text',
    max_seq_length=1024,
    tokenizer=tokenizer,
    args=training_arguments
)
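With the trainer configured, training is launched with a single call; checkpoints and evaluation results are written to the output directory defined above.
trainer.train()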
Evaluation
We will evaluate the model using the bert-score and exact-match metrics.
from evaluate import load
bert_model = "microsoft/deberta-v2-xxlarge-mnli"
bertscore = load("bertscore")
exact_match_metric = load("exact_match")
def get_bulk_predictions(pipe, samples):
    # ...

def get_base_and_tuned_bulk_predictions(samples):
    # ...
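The prediction helpers are elided above. Below is a minimal sketch of get_bulk_predictions, under the assumption that pipe is a transformers text-generation pipeline and that get_base_and_tuned_bulk_predictions simply calls it once with a pipeline around the base model and once with one around the fine-tuned model, storing the outputs as base_prediction and trained_prediction.
def get_bulk_predictions(pipe, samples):
    # Rebuild each prompt without the assistant turn and let the model complete it.
    prompts = [
        pipe.tokenizer.apply_chat_template(m[:-1], tokenize=False, add_generation_prompt=True)
        for m in samples["messages"]
    ]
    outputs = pipe(prompts, max_new_tokens=64, do_sample=False, return_full_text=False)
    return [out[0]["generated_text"].strip() for out in outputs]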
conversation_validation_samples = conversation_validation_samples.map(get_base_and_tuned_bulk_predictions, batched=True, batch_size=20)
base_predictions = conversation_validation_samples['base_prediction']
answers = conversation_validation_samples['answer']
base_validation_bert_score = bertscore.compute(predictions=base_predictions, references=answers, lang="en", model_type=bert_model, device="cuda:0")
baseline_exact_match_score = exact_match_metric.compute(predictions=base_predictions, references=answers)
trained_predictions = conversation_validation_samples['trained_prediction']
trained_validation_bert_score = bertscore.compute(predictions=trained_predictions, references=answers, lang="en", model_type=bert_model, device="cuda:0")
tuned_exact_match_score = exact_match_metric.compute(predictions=trained_predictions, references=answers)
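bertscore returns per-sample precision, recall, and F1 lists rather than a single number, so the scores are averaged before reporting (the assumption here is that the table below shows the mean F1):
import numpy as np

print("Base BERTScore F1:", np.mean(base_validation_bert_score["f1"]))
print("Tuned BERTScore F1:", np.mean(trained_validation_bert_score["f1"]))
print("Base exact match:", baseline_exact_match_score["exact_match"])
print("Tuned exact match:", tuned_exact_match_score["exact_match"])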
Results
The training process took around 1 hour on an A100 GPU. The results show a significant improvement in the model's performance on the validation set.
| Metric      | Base Model | Tuned Model |
|-------------|------------|-------------|
| BERT Score  | 0.6469     | 0.7505      |
| Exact Match | 0.116      | 0.418       |
Conclusion
This article demonstrated how to fine-tune Llama 3.2 11B for extractive question answering using the QLoRA technique. The results show a significant improvement on the validation set: the BERT score rose from roughly 0.65 to 0.75 and the exact-match score from 0.12 to 0.42. The same technique can be applied to other tasks and models, and we hope this article serves as a comprehensive guide for future research and applications.