sharpbyte.dev
Fine-tuning · topic 5 of 11

How LoRA works

Factor ΔW ≈ AB with rank r, train only A and B, and interpret the result as an expressivity-vs-rank trade-off grid.

Why ΔW factors into A and B

How does LoRA work?

Now, you might be thinking…

But how can we even add two matrices if they have different dimensions?

It’s true, we can’t do that.

More specifically, during fine-tuning the weight matrix W is frozen, so it receives no gradient updates. All gradient updates are instead directed to the ΔW matrix. For W and ΔW to be added into the final fine-tuned weight W + ΔW, ΔW must have the same dimensions as W. Rather than training ΔW directly, LoRA expresses it as the product of two low-rank matrices, A and B, which contain all of the trainable parameters.

Low-rank factorization ΔW = AB with trainable A, B.
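The additivity argument can be checked numerically. Below is a minimal NumPy sketch, assuming inputs are row vectors so that a layer computes x @ W with W of shape d × k; the sizes, the random initialization and the variable names are illustrative, not taken from the original. It shows that A on its own cannot be added to W, that AB has the same shape as W, and that folding AB into W gives the same output as running the frozen path and the low-rank path separately.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes: W is d x k as in the text; the rank r is much smaller.
d, k, r = 64, 32, 4

W = rng.normal(size=(d, k))        # frozen pretrained weight, d x k
A = rng.normal(size=(d, r)) * 0.1  # trainable factor, d x r
B = rng.normal(size=(r, k)) * 0.1  # trainable factor, r x k

# A alone cannot be added to W: shapes (64, 32) and (64, 4) do not broadcast.
try:
    W + A
except ValueError:
    print("W + A is not defined")

delta_W = A @ B            # (d x r) @ (r x k) -> d x k
W_finetuned = W + delta_W  # the shapes now match, so addition is valid
print(delta_W.shape)       # (64, 32)

x = rng.normal(size=(5, d))  # a batch of 5 input row vectors

merged = x @ W_finetuned        # fold delta_W into W first
separate = x @ W + (x @ A) @ B  # frozen path + low-rank path
print(np.allclose(merged, separate))  # True
```

Because the two forms agree, the low-rank path can be kept separate while A and B are trained and merged into W afterwards.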

As discussed earlier, the dimensions of W are d × k:

W shape d×k implies ΔW must add compatibly.

Tiny factors, full-size product

Thus, ΔW must also be d × k. But this does not mean that A and B together must contain d × k trainable parameters.

Instead, A and B can be extremely small matrices; the only requirement is that their product has dimensions d × k.

Thus:

  • Matrix A has dimensions d × r, where the rank r is much smaller than d and k.
  • Matrix B has dimensions r × k.
  • Their product AB therefore has dimensions d × k, as required.
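To put a number on how small A and B can be, the helper below counts the parameters of a full d × k update versus the two factors. The 4096 × 4096 layer and r = 8 are assumed values chosen for illustration, and `lora_param_counts` is a hypothetical helper, not part of any library.

```python
def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Return (parameters in a full d x k update, parameters in A (d x r) plus B (r x k))."""
    full = d * k
    lora = d * r + r * k
    return full, lora


# Assumed layer: a 4096 x 4096 projection adapted with rank r = 8.
full, lora = lora_param_counts(4096, 4096, 8)
print(full)                  # 16777216
print(lora)                  # 65536
print(f"{lora / full:.2%}")  # 0.39%
```

The full update grows with d · k, while the factors grow only with r · (d + k), which is why a small r keeps the trainable parameter count low.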

During training, only matrices A and B are trained, while all of the pretrained network's weights stay frozen.
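The following NumPy sketch illustrates training only A and B. It assumes a synthetic regression task whose target weights differ from W by a low-rank shift, a mean-squared-error loss, plain gradient descent with hand-derived gradients, and B initialized to zero (a common choice so that AB starts at zero); all of these are illustrative assumptions. W is read in the forward pass but never written, so its values are unchanged after training.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, n = 8, 6, 2, 64  # illustrative sizes and sample count

W = rng.normal(size=(d, k))  # frozen pretrained weight
# Hypothetical low-rank "task" shift that the adapter should learn.
target_delta = rng.normal(size=(d, r)) @ rng.normal(size=(r, k)) * 0.3
x = rng.normal(size=(n, d))
y = x @ (W + target_delta)   # synthetic targets

A = rng.normal(size=(d, r)) / np.sqrt(d)  # trainable, d x r
B = np.zeros((r, k))                      # trainable, r x k; zero init so AB starts at 0
W_before = W.copy()


def loss_and_grads(A, B):
    err = x @ W + (x @ A) @ B - y           # n x k residual
    loss = 0.5 * np.mean(np.sum(err ** 2, axis=1))
    G = err / n                             # dLoss/d(output)
    grad_B = (x @ A).T @ G                  # r x k
    grad_A = x.T @ G @ B.T                  # d x r
    return loss, grad_A, grad_B             # no gradient is computed for W


lr = 0.05
losses = []
for step in range(200):
    loss, grad_A, grad_B = loss_and_grads(A, B)
    losses.append(loss)
    A -= lr * grad_A  # only the two factors are updated
    B -= lr * grad_B

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
print(np.array_equal(W, W_before))  # True: W was never modified
```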

This is how LoRA works.

More formally, LoRA challenges the very idea of full-model fine-tuning by asking two questions:

  • Do we really need to fine-tune all the parameters in the original model?
  • How expressive does the update need to be, i.e., what rank should ΔW have?

This can be plotted as a 2D grid, as shown below:

2D grid of LoRA configurations; corner = full fine-tuning.

In the above image, every point denotes a possible LoRA configuration. Also, the upper right corner refers to full fine-tuning.
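The rank axis of the grid can be explored numerically. This sketch, for an assumed 64 × 48 layer, draws random factors for several ranks r and reports the trainable parameter count, that count as a fraction of d × k, and the rank of AB. The helper `lora_grid_row` is hypothetical, and the sketch covers only the rank dimension of the grid.

```python
import numpy as np


def lora_grid_row(d: int, k: int, r: int, rng: np.random.Generator) -> tuple[int, float, int]:
    """Trainable params, their fraction of a full d x k update, and rank(AB) for random factors."""
    A = rng.normal(size=(d, r))
    B = rng.normal(size=(r, k))
    trainable = d * r + r * k
    return trainable, trainable / (d * k), int(np.linalg.matrix_rank(A @ B))


rng = np.random.default_rng(0)
d, k = 64, 48  # assumed layer shape
rows = {r: lora_grid_row(d, k, r, rng) for r in [1, 2, 4, 8, 16, 48]}
for r, (trainable, frac, rank) in rows.items():
    print(f"r={r:2d}  trainable={trainable:5d}  fraction={frac:.3f}  rank(AB)={rank}")
```

rank(AB) can never exceed r, which is what limits expressivity at low rank. At r = min(d, k) = 48 the product can represent any 64 × 48 update, matching the expressivity of full fine-tuning, although the two factors then hold more numbers than ΔW itself.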

Experimentally, good configurations are often found toward the bottom-left corner of the grid (low rank, few trainable parameters), which means that we do not need to train all of the model's parameters.
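A simple way to see why low rank can be enough: if the update a task needs is itself low-rank, a factorization with that rank loses nothing. The sketch below assumes a synthetic ΔW of rank 4 and uses the truncated SVD (the best rank-r approximation in Frobenius norm, by the Eckart–Young theorem) to build A and B for several values of r. Real fine-tuning updates are not guaranteed to be exactly low-rank; this only illustrates the mechanism, and `best_rank_r_error` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, true_rank = 64, 48, 4
# Hypothetical update that is low-rank by construction.
delta_W = rng.normal(size=(d, true_rank)) @ rng.normal(size=(true_rank, k))

U, S, Vt = np.linalg.svd(delta_W, full_matrices=False)


def best_rank_r_error(r: int) -> float:
    """Relative Frobenius error of the best rank-r factorization A @ B of delta_W."""
    A = U[:, :r] * S[:r]  # d x r: left singular vectors scaled by singular values
    B = Vt[:r, :]         # r x k
    return float(np.linalg.norm(delta_W - A @ B) / np.linalg.norm(delta_W))


for r in [1, 2, 4, 8]:
    print(f"r={r}: relative error {best_rank_r_error(r):.2e}")
```

Once r reaches the true rank of the update, the error drops to floating-point noise, and raising r further only adds trainable parameters.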

Now that we understand how LoRA works, let’s proceed with implementing LoRA.

Implementation. Although a few open-source implementations of LoRA are already available, we shall implement it from scratch using PyTorch to get a better idea of the practical details.

As discussed above, a typical LoRA layer comprises two matrices, A and B. These have been implemented in the LoRAWeights class below along with the forward pass:

LoRAWeights class and forward pass (PyTorch) from the deck.
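The `LoRAWeights` class itself appears only as an image in the deck, so what follows is not that code. It is a hypothetical PyTorch sketch, `LoRALinear`, that wraps a frozen `nn.Linear` and adds a trainable A @ B path with the same d × r and r × k shapes as the text. Initializing B to zero, so the wrapped layer initially behaves exactly like the base layer, is an assumption borrowed from common practice. Note that `nn.Linear` stores its weight as out × in, which is the transpose of W here.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical sketch: a frozen nn.Linear plus a trainable low-rank update A @ B."""

    def __init__(self, base: nn.Linear, r: int = 4):
        super().__init__()
        self.base = base
        # Freeze the pretrained weight and bias: they receive no gradients.
        for p in self.base.parameters():
            p.requires_grad_(False)
        d, k = base.in_features, base.out_features
        self.A = nn.Parameter(torch.randn(d, r) / d ** 0.5)  # d x r
        self.B = nn.Parameter(torch.zeros(r, k))             # r x k, zero init so A @ B = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen path plus low-rank path; x is (..., d), output is (..., k).
        return self.base(x) + (x @ self.A) @ self.B

    def merged_weight(self) -> torch.Tensor:
        # nn.Linear's weight is (k, d), so the d x k update is transposed before adding.
        return self.base.weight + (self.A @ self.B).T


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(16, 8), r=4)
x = torch.randn(5, 16)

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 96 = 16*4 + 4*8

loss = layer(x).pow(2).sum()
loss.backward()
print(layer.base.weight.grad is None)  # True: the base weight stays frozen
print(layer.B.grad.shape)              # torch.Size([4, 8])

with torch.no_grad():
    layer.B.normal_()  # pretend training has moved B away from zero
    merged_out = nn.functional.linear(x, layer.merged_weight(), layer.base.bias)
    print(torch.allclose(merged_out, layer(x), atol=1e-5))  # True
```

The `merged_weight` method shows the payoff of keeping ΔW additive: after training, A @ B can be folded into the base weight, so inference needs no extra matrix multiply.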

Key takeaways

  • LoRA trains small matrices A and B whose product AB has the same d × k shape as ΔW, so it can be added directly to the frozen W.
  • Empirically, strong results often come at low rank, with most weights left frozen.