1. Flash-Attention2
1.1概述
与Flash Attention相比,Flash Attention2的速度提高了约2倍,在A100上达到理论最大FLOP/s的50-73%,1版本只能做到理论最大的25-40%。
(1)调整算法以减少非矩阵乘法FLOP的数量,GPU上有专门的计算单元,这使得矩阵乘法更快
(2)在不同的线程块上并行化注意力计算,以增加显存占用率
(3)在每个线程块内,将工作划分为不同的线程束,以减少通过共享内存的通信
1.2技术实现
算法实现上,交换了内外循环,通过logsumexp保存了中间状态值,便于后续进行反向传播,同时,算法还在序列上进行了并行计算,提高了显存占用率;
在正向传播和反向传播过程中,线程块(worker)并行计算的方向是不一样的,正向传播是对矩阵按行划分工作,反向传播是按列进行划分线程块进行工作。
在每个线程块下,又划分了4到8个线程束对序列进行并行化计算,Flash-Attention是按照K进行线程束划分,而这样做需要进行不断的通信进行共享内存的读写;Flash-Attention2则对此进行了优化,将线程束划分优化改为在Q矩阵上的划分,从而减少了共享内存的反复读写(对比两个算法计算softmax的过程,Flash-Attention2不用再反复读写中间量m和l)。
2. 总结
● Flash-Attention2是一种IO感知的精确注意力算法,通过降低HBM的读写,从而提升了速度;
● Flash-Attention2通过降低长序列的计算复杂度,从而提升了序列支持长度;
● Flash-Attention可以加速2~4倍,而Flash-Attention2比Flash-Attention速度快1.7~3倍。