Finetuning Llama2 with LoRA — TorchTune documentation (2024)

This guide will teach you about LoRA, a parameter-efficient finetuning technique, and show you how you can use torchtune to finetune a Llama2 model with LoRA. If you already know what LoRA is and want to get straight to running your own LoRA finetune in torchtune, you can jump to the section "LoRA finetuning recipe in torchtune".

What you will learn

  • What LoRA is and how it saves memory during finetuning

  • An overview of LoRA components in torchtune

  • How to run a LoRA finetune using torchtune

  • How to experiment with different LoRA configurations

Prerequisites

  • Be familiar with torchtune

  • Make sure to install torchtune

  • Make sure you have downloaded the Llama2-7B model weights

What is LoRA?

LoRA is an adapter-based method for parameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network, then freezes the network’s remaining parameters. LoRA is most commonly applied to transformer models, in which case it is common to add the low-rank matrices to some of the linear projections in each transformer layer’s self-attention.

By finetuning with LoRA (as opposed to finetuning all model parameters), you can expect to see memory savings due to a substantial reduction in the number of parameters with gradients. When using an optimizer with momentum, like AdamW, you can expect to see further memory savings from the optimizer state.

Note

LoRA memory savings come primarily from gradient and optimizer states, so if your model’s peak memory comes in its forward() method, then LoRA may not reduce peak memory.
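To make the savings concrete, here is a rough back-of-the-envelope sketch. It assumes 4 bytes per value (fp32) for gradients and optimizer state, and two optimizer-state values per trainable parameter (the first and second moments kept by AdamW). Real runs also hold weights, activations, and other buffers, so treat the numbers as illustrative only. The parameter counts are the ones reported later in this guide for Llama2-7B with rank-8 LoRA on the Q and V projections; grad_and_optimizer_bytes is a hypothetical helper, not part of torchtune.

```python
def grad_and_optimizer_bytes(trainable_params: int, bytes_per_value: int = 4) -> int:
    """Rough memory for gradients plus AdamW state (two moments per parameter)."""
    gradients = trainable_params * bytes_per_value
    adamw_state = 2 * trainable_params * bytes_per_value
    return gradients + adamw_state


# Parameter counts taken from the printout later in this guide.
lora_params = 4_194_304                    # trainable LoRA parameters
full_params = 6_742_609_920 - lora_params  # parameters of the base model alone

full_gib = grad_and_optimizer_bytes(full_params) / 2**30
lora_gib = grad_and_optimizer_bytes(lora_params) / 2**30
print(f"full finetune: {full_gib:.1f} GiB, LoRA: {lora_gib:.3f} GiB")
```

Under these assumptions, the gradient and optimizer footprint scales directly with the number of trainable parameters, which is why freezing the base model matters so much.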

How does LoRA work?

LoRA replaces weight update matrices with a low-rank approximation. In general, weight updates for an arbitrary nn.Linear(in_dim, out_dim) layer could have rank as high as min(in_dim, out_dim). LoRA (and other related papers such as Aghajanyan et al.) hypothesize that the intrinsic dimension of these updates during LLM fine-tuning can in fact be much lower. To take advantage of this property, LoRA finetuning will freeze the original model, then add a trainable weight update from a low-rank projection. More explicitly, LoRA trains two matrices A and B. A projects the inputs down to a much smaller rank (often four or eight in practice), and B projects back up to the dimension output by the original linear layer.
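In symbols, the decomposition described above can be sketched as follows. The notation is our own rather than the guide's: \(W_0\) is the frozen pretrained weight, \(\Delta W\) the weight update, \(x\) the layer input, and \(h\) the layer output. Implementations typically also scale the low-rank term by alpha / r, as the minimal PyTorch implementation later in this guide does.

```latex
\Delta W \approx B A, \qquad
A \in \mathbb{R}^{r \times \mathrm{in\_dim}}, \quad
B \in \mathbb{R}^{\mathrm{out\_dim} \times r}, \quad
r \ll \min(\mathrm{in\_dim}, \mathrm{out\_dim})

h = W_0 x + B A x, \qquad
\operatorname{rank}(B A) \le r
```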

The image below gives a simplified representation of a single weight update step from a full finetune (on the left) compared to a weight update step with LoRA (on the right). The LoRA matrices A and B serve as an approximation to the full rank weight update in blue.

[Figure: a weight update step in a full finetune (left) compared with a weight update step with LoRA (right)]

Although LoRA introduces a few extra parameters in the model forward(), only the A and B matrices are trainable. This means that with a rank r LoRA decomposition, the number of gradients we need to store reduces from in_dim*out_dim to r*(in_dim+out_dim). (Remember that in general r is much smaller than in_dim and out_dim.)

For example, in the 7B Llama2’s self-attention, in_dim=out_dim=4096 for the Q, K, and V projections. This means a LoRA decomposition of rank r=8 will reduce the number of trainable parameters for a given projection from \(4096 \times 4096 \approx 16.8M\) to \(8 \times 8192 \approx 65K\), a reduction of over 99%.
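The arithmetic for a single projection can be checked with a few lines of plain Python. The two helper functions are hypothetical names introduced only for this illustration.

```python
def full_update_params(in_dim: int, out_dim: int) -> int:
    """Trainable parameters when the whole weight matrix is updated."""
    return in_dim * out_dim


def lora_update_params(in_dim: int, out_dim: int, rank: int) -> int:
    """Trainable parameters in the LoRA A (rank x in_dim) and B (out_dim x rank) matrices."""
    return rank * (in_dim + out_dim)


full = full_update_params(4096, 4096)
lora = lora_update_params(4096, 4096, rank=8)
reduction = 1 - lora / full
print(f"{full:,} -> {lora:,} trainable parameters ({reduction:.2%} fewer)")
```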

Let’s take a look at a minimal implementation of LoRA in native PyTorch.

    from torch import nn, Tensor


    class LoRALinear(nn.Module):
        def __init__(
            self, in_dim: int, out_dim: int, rank: int, alpha: float, dropout: float
        ):
            # nn.Module must be initialized before any submodules are assigned
            super().__init__()
            # These are the weights from the original pretrained model
            self.linear = nn.Linear(in_dim, out_dim, bias=False)

            # These are the new LoRA params. In general rank << in_dim, out_dim
            self.lora_a = nn.Linear(in_dim, rank, bias=False)
            self.lora_b = nn.Linear(rank, out_dim, bias=False)

            # Rank and alpha are commonly-tuned hyperparameters
            self.rank = rank
            self.alpha = alpha

            # Most implementations also include some dropout
            self.dropout = nn.Dropout(p=dropout)

            # The original params are frozen, and only LoRA params are trainable.
            self.linear.weight.requires_grad = False
            self.lora_a.weight.requires_grad = True
            self.lora_b.weight.requires_grad = True

        def forward(self, x: Tensor) -> Tensor:
            # This would be the output of the original model
            frozen_out = self.linear(x)

            # lora_a projects inputs down to the much smaller self.rank,
            # then lora_b projects back up to the output dimension
            lora_out = self.lora_b(self.lora_a(self.dropout(x)))

            # Finally, scale by the alpha parameter (normalized by rank)
            # and add to the original model's outputs
            return frozen_out + (self.alpha / self.rank) * lora_out
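One initialization choice, described in the original LoRA paper, is to start B at zero so that the adapted layer initially produces exactly the frozen layer's output. The sketch below illustrates that property in plain Python, with no torch dependency and no dropout. It is not torchtune's implementation; matvec and lora_forward are hypothetical helpers and the dimensions are arbitrary small values.

```python
import random


def matvec(matrix: list[list[float]], vector: list[float]) -> list[float]:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def lora_forward(w, a, b, x, alpha: float, rank: int) -> list[float]:
    """Frozen output plus the scaled low-rank correction B(Ax)."""
    frozen_out = matvec(w, x)
    lora_out = matvec(b, matvec(a, x))
    scale = alpha / rank
    return [f + scale * l for f, l in zip(frozen_out, lora_out)]


random.seed(0)
in_dim, out_dim, rank, alpha = 6, 4, 2, 4.0
w = [[random.gauss(0, 1) for _ in range(in_dim)] for _ in range(out_dim)]  # frozen weight
a = [[random.gauss(0, 1) for _ in range(in_dim)] for _ in range(rank)]     # A: rank x in_dim
b = [[0.0] * rank for _ in range(out_dim)]                                 # B: out_dim x rank, all zeros
x = [random.gauss(0, 1) for _ in range(in_dim)]

# With B at zero, the LoRA branch contributes nothing at the start of training.
print(lora_forward(w, a, b, x, alpha, rank) == matvec(w, x))
```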

There are some other details around initialization which we omit here, but if you’d like to know more you can see our implementation in LoRALinear. Now that we understand what LoRA is doing, let’s look at how we can apply it to our favorite models.

Applying LoRA to Llama2 models

With torchtune, we can easily apply LoRA to Llama2 with a variety of different configurations. Let’s take a look at how to construct Llama2 models in torchtune with and without LoRA.

    from torchtune.models.llama2 import llama2_7b, lora_llama2_7b

    # Build Llama2 without any LoRA layers
    base_model = llama2_7b()

    # The default settings for lora_llama2_7b will match those for llama2_7b
    # We just need to define which layers we want LoRA applied to.
    # Within each self-attention, we can choose from ["q_proj", "k_proj", "v_proj", and "output_proj"].
    # We can also set apply_lora_to_mlp=True or apply_lora_to_output=True to apply LoRA to other linear
    # layers outside of the self-attention.
    lora_model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

Note

Calling lora_llama2_7b alone will not handle the definition of which parameters are trainable. See below for how to do this.

Let’s inspect each of these models a bit more closely.

    # Print the first layer's self-attention in the usual Llama2 model
    >>> print(base_model.layers[0].attn)
    CausalSelfAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (pos_embeddings): RotaryPositionalEmbeddings()
    )

    # Print the same for Llama2 with LoRA weights
    >>> print(lora_model.layers[0].attn)
    CausalSelfAttention(
      (q_proj): LoRALinear(
        (dropout): Dropout(p=0.0, inplace=False)
        (lora_a): Linear(in_features=4096, out_features=8, bias=False)
        (lora_b): Linear(in_features=8, out_features=4096, bias=False)
      )
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): LoRALinear(
        (dropout): Dropout(p=0.0, inplace=False)
        (lora_a): Linear(in_features=4096, out_features=8, bias=False)
        (lora_b): Linear(in_features=8, out_features=4096, bias=False)
      )
      (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (pos_embeddings): RotaryPositionalEmbeddings()
    )

Notice that our LoRA model’s layer contains additional weights in the Q and V projections, as expected. Additionally, inspecting the type of lora_model and base_model would show that they are both instances of the same TransformerDecoder. (Feel free to verify this for yourself.)

Why does this matter? torchtune makes it easy to load checkpoints for LoRA directly from our Llama2 model without any wrappers or custom checkpoint conversion logic.

    # Assuming that base_model already has the pretrained Llama2 weights,
    # this will directly load them into your LoRA model without any conversion necessary.
    lora_model.load_state_dict(base_model.state_dict(), strict=False)

Note

Whenever loading weights with strict=False, you should verify that any missing or extra keys in the loaded state_dict are as expected. torchtune’s LoRA recipes do this by default via e.g. torchtune.modules.peft.validate_state_dict_for_lora().
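As a rough illustration of what such a check might look like, the sketch below works on plain lists of key names (in PyTorch, load_state_dict returns the missing and unexpected keys). check_lora_load is a hypothetical helper, not torchtune's validate_state_dict_for_lora(), and the key names are assumed for the example.

```python
def check_lora_load(missing_keys, unexpected_keys, adapter_markers=("lora_a", "lora_b")):
    """Raise if anything other than LoRA adapter weights is missing or extra."""
    if unexpected_keys:
        raise ValueError(f"Unexpected keys in base checkpoint: {unexpected_keys}")
    non_adapter_missing = [
        key for key in missing_keys
        if not any(marker in key for marker in adapter_markers)
    ]
    if non_adapter_missing:
        raise ValueError(f"Base model weights were not loaded: {non_adapter_missing}")


# Hypothetical key names: only the new LoRA weights are absent from the base checkpoint.
missing = [
    "layers.0.attn.q_proj.lora_a.weight",
    "layers.0.attn.q_proj.lora_b.weight",
]
check_lora_load(missing, unexpected_keys=[])
print("state dict keys look as expected")
```

The idea is that missing LoRA weights are fine, since the base checkpoint never contained them, while any other missing or extra key signals a mismatch.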

Once we’ve loaded the base model weights, we also want to set only LoRA parameters to trainable.

    from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params

    # Fetch all params from the model that are associated with LoRA.
    lora_params = get_adapter_params(lora_model)

    # Set requires_grad=True on lora_params, and requires_grad=False on all others.
    set_trainable_params(lora_model, lora_params)

    # Print the total number of parameters
    total_params = sum([p.numel() for p in lora_model.parameters()])
    trainable_params = sum([p.numel() for p in lora_model.parameters() if p.requires_grad])
    print(
        f"""
        {total_params} total params,
        {trainable_params} trainable params,
        {(100.0 * trainable_params / total_params):.2f}% of all params are trainable.
        """
    )

    6742609920 total params,
    4194304 trainable params,
    0.06% of all params are trainable.
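The trainable parameter count above can be reproduced by hand from the formula r*(in_dim+out_dim). The sketch below assumes Llama2-7B has 32 decoder layers, a figure not stated in this guide, and uses the rank and dimensions shown in the printed modules.

```python
num_layers = 32             # assumption: number of decoder layers in Llama2-7B
embed_dim = 4096            # in_dim and out_dim of the Q and V projections
rank = 8                    # default LoRA rank shown in the printed modules
lora_modules_per_layer = 2  # q_proj and v_proj

per_module = rank * (embed_dim + embed_dim)  # parameters in lora_a plus lora_b
trainable = num_layers * lora_modules_per_layer * per_module
total = 6_742_609_920                        # total parameter count printed above

print(f"{trainable} trainable params, {100.0 * trainable / total:.2f}% of all params")
```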

Note

If you are directly using the LoRA recipe (as detailed here), you need only pass the relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care of in the recipe.

LoRA finetuning recipe in torchtune

Finally, we can put it all together and finetune a model using torchtune’s LoRA recipe. Make sure that you have first downloaded the Llama2 weights and tokenizer by following these instructions. You can then run the following command to perform a LoRA finetune of Llama2-7B with two GPUs (each having VRAM of at least 16GB):

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora

Note

Make sure to point to the location of your Llama2 weights and tokenizer. This can be done either by adding checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path to the command above or by directly modifying the 7B_lora.yaml file. See our All About Configs for more details on how you can easily clone and modify torchtune configs.

Note

You can modify the value of nproc_per_node depending on (a) the number of GPUs you have available, and (b) the memory constraints of your hardware.

The preceding command will run a LoRA finetune with torchtune’s factory settings, but we may want to experiment a bit. Let’s take a closer look at part of the lora_finetune_distributed config.

    # Model Arguments
    model:
      _component_: lora_llama2_7b
      lora_attn_modules: ['q_proj', 'v_proj']
      lora_rank: 8
      lora_alpha: 16
    ...

We see that the default is to apply LoRA to Q and V projections with a rank of 8. Some experiments with LoRA have found that it can be beneficial to apply LoRA to all linear layers in the self-attention, and to increase the rank to 16 or 32. Note that this is likely to increase our max memory, but as long as we keep rank<<embed_dim, the impact should be relatively minor.

Let’s run this experiment. We can also increase alpha (in general it is good practice to scale alpha and rank together).

    tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \
    lora_attn_modules=['q_proj','k_proj','v_proj','output_proj'] \
    lora_rank=32 lora_alpha=64 output_dir=./lora_experiment_1
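One way to see why alpha and rank are scaled together: in the minimal implementation earlier in this guide, the LoRA output is multiplied by alpha / rank, so doubling both leaves that multiplier unchanged. A small sketch, with lora_scale as a hypothetical helper name:

```python
def lora_scale(alpha: float, rank: int) -> float:
    """Multiplier applied to the low-rank branch in the minimal implementation."""
    return alpha / rank


# (rank, alpha) pairs: the default config and the experiment above.
configs = [(8, 16), (32, 64)]
for rank, alpha in configs:
    print(f"rank={rank}, alpha={alpha}, scale={lora_scale(alpha, rank)}")
```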

A comparison of the (smoothed) loss curves between this run and our baseline over the first 500 steps can be seen below.

[Figure: smoothed loss curves for this run and the baseline over the first 500 steps]

Note

The above figure was generated with W&B. You can use torchtune’s WandBLogger to generate similar loss curves, but you will need to install W&B and set up an account separately.

Trading off memory and model performance with LoRA

In the preceding example, we ran LoRA on two devices. But given LoRA’s low memory footprint, we can run fine-tuning on a single device using most commodity GPUs that support the bfloat16 floating-point format. This can be done via the command:

tune run lora_finetune_single_device --config llama2/7B_lora_single_device

On a single device, we may need to be more cognizant of our peak memory. Let’s run a few experiments to see our peak memory during a finetune. We will experiment along two axes: first, which model layers have LoRA applied, and second, the rank of each LoRA layer. (We will scale alpha in parallel to LoRA rank, as discussed above.)

To compare the results of our experiments, we can evaluate our models on truthfulqa_mc2, a task from the TruthfulQA benchmark for language models. For more details on how to run this and other evaluation tasks with torchtune’s EleutherAI evaluation harness integration, see our End-to-End Workflow Tutorial.

Previously, we only enabled LoRA for the linear layers in each self-attention module, but in fact there are other linear layers we can apply LoRA to: MLP layers and our model’s final output projection. Note that for Llama-2-7B the final output projection maps to dimension 32000 (instead of 4096 as in the other linear layers), so enabling LoRA for this layer will increase our peak memory a bit more than the other layers. We can make the following changes to our config:

    # Model Arguments
    model:
      _component_: lora_llama2_7b
      lora_attn_modules: ['q_proj', 'k_proj', 'v_proj', 'output_proj']
      apply_lora_to_mlp: True
      apply_lora_to_output: True
    ...
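To get a feel for why the final output projection costs more than a self-attention projection, we can compare LoRA parameter counts using r*(in_dim+out_dim). This sketch assumes rank 8 and only uses the dimensions mentioned above; lora_params is a hypothetical helper, and the MLP layers are left out because their dimensions are not given in this guide.

```python
def lora_params(in_dim: int, out_dim: int, rank: int) -> int:
    """Parameters in the LoRA A and B matrices for one linear layer."""
    return rank * (in_dim + out_dim)


rank = 8
attn_proj = lora_params(4096, 4096, rank)     # a single Q, K, V, or output_proj layer
final_output = lora_params(4096, 32000, rank) # the model's final output projection

print(f"attention projection: {attn_proj:,} LoRA params")
print(f"final output projection: {final_output:,} LoRA params")
print(f"ratio: {final_output / attn_proj:.2f}x")
```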

Note

All the finetuning runs below use the llama2/7B_lora_single_device config, which has a default batch size of 2. Modifying the batch size (or other hyperparameters, e.g. the optimizer) will impact both peak memory and final evaluation results.

LoRA Layers     Rank   Alpha   Peak Memory   Accuracy (truthfulqa_mc2)
------------    ----   -----   -----------   -------------------------
Q and V only    8      16      15.57 GB      0.475
all layers      8      16      15.87 GB      0.508
Q and V only    64     128     15.86 GB      0.504
all layers      64     128     17.04 GB      0.514

We can see that our baseline settings give the lowest peak memory, but our evaluation performance is relatively lower. By enabling LoRA for all linear layers and increasing the rank to 64, we see an absolute improvement of almost 4 percentage points in our accuracy on this task, but our peak memory also increases by about 1.5 GB. These are just a couple of simple experiments; we encourage you to run your own finetunes to find the right tradeoff for your particular setup.
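The tradeoff can be read directly off the table. The short sketch below copies the reported values into a dictionary and computes the differences between the baseline and the largest configuration; the dictionary layout is our own choice for illustration.

```python
# (LoRA layers, rank) -> results copied from the table above
results = {
    ("Q and V only", 8): {"peak_gb": 15.57, "accuracy": 0.475},
    ("all layers", 8): {"peak_gb": 15.87, "accuracy": 0.508},
    ("Q and V only", 64): {"peak_gb": 15.86, "accuracy": 0.504},
    ("all layers", 64): {"peak_gb": 17.04, "accuracy": 0.514},
}

baseline = results[("Q and V only", 8)]
largest = results[("all layers", 64)]

acc_gain = largest["accuracy"] - baseline["accuracy"]
mem_increase = largest["peak_gb"] - baseline["peak_gb"]
print(f"accuracy: +{acc_gain * 100:.1f} points, peak memory: +{mem_increase:.2f} GB")
```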

Additionally, if you want to decrease your model’s peak memory even further (and still potentially achieve similar model quality results), you can check out our QLoRA tutorial.
