LoRa原理
LoRA 的思想很简单:
● 在原始 PLM (Pre-trained Language Model) 旁边增加一个旁路,做一个降维再升维的操作,来模拟所谓的intrinsic rank。
● 训练的时候固定 PLM 的参数,只训练低秩矩阵 𝐴 与矩阵 𝐵 。而模型的输入输出维度不变,输出时将 𝐵𝐴 与 PLM 的参数叠加。
● 用随机高斯分布初始化 𝐴 ,用 0 矩阵初始化 𝐵 ,保证训练的开始此旁路矩阵依然是 0 矩阵。
假设要在下游任务微调一个预训练语言模型(如 GPT-3),则需要更新预训练模型参数,公式表示如下:W0+\delta{W}
W0是预训练模型初始化的参数,\delta{W}就是需要更新的参数。如果是全参数微调,则它的参数量等于W0,如果是 GPT-3,则\delta{W}≈175B )。从这可以看出要全参数微调大语言模型,代价是非常高的
显存和计算量分析
显存分析
主干模型部分
首先主干模型的权重都要存储在显存中,这部分显存无法省掉
其次,虽然只对 LoRA 部分的模型进行优化,但是想要求 LoRA 部分的梯度,那么主干的梯度也是必须要求解出来的,所以主干模型的梯度是必须要求的。
由于不需要优化主干模型,所以主干模型对应的优化器不需要存储,这部分显存可以节省。
LoRa模型部分
LoRA 模型的权重、梯度、优化器状态都需要存储,这个是没有疑问的。
计算量分析
涉及到计算的主要分为前向传播、反向传播、优化器更新权重,这三部分。下面也主要是看这三部分中哪部分可以省掉。
结论:LoRA 在计算量上和全量参数微调基本是一致的。
在实际训练中,还是能够感受到使用 LoRA 时速度变快了,这个的原因一般有:
(1)使用 LoRA 时会对主干模型做 int8 甚至是 int4 的量化,使得主干模型的前向传播和反向传播耗时减少;
(2)多卡训练(数据并行)时,卡间通信只需要同步 LoRA 模型部分的梯度,大大减小的通信的压力,也会使用总训练速度变快。