1. 前置知识
HBM:High Bandwidth Memory,即高带宽内存,是一款新型的CPU/GPU内存芯片。将很多个DDR(双数据率同步动态随机存储器)芯片堆叠在一起后和GPU封装在一起,实现大容量、高位宽的DDR组合阵列。
线程束和线程块:一个线程束由32个连续的线程组成,在一个线程束中,所有的线程按照单指令多线程(SIMT)方式执行,线程块是一维线程束的集合。
2. Flash-Attention
1. 概述
一种IO-aware的精确注意力算法,核心在于减少内存访问。(训练框架如pytorch不能对内存进行细粒度的控制)
(1) 在不访问整个输入的情况下计算softmax;
(2)不存储用于反向传播的大的中间注意力矩阵。
解决方案:
(1)重组注意力计算,将输入划分为块,并在输入块上进行多次遍历,从而逐步执行softmax;
(2)存储来自前向传递的softmax归一化因子,以在反向传播中快速重新计算片上注意力。
2. 技术实现
Flash Attention的思想是降低HBM访问次数,从而降低所需时间,提高训练性能,核心在于tiling,即平铺。
标准的Attention实现,每次计算时都会访问HBM去load对应的Qi、Kj、Vj,这样一来的HBM访问复杂度约为O(Nd+N^2)
标准Attention计算过程忽略了HBM访问的时间消耗,如果能够删去或减少HBM的访问次数,那么效率提升也就水到渠成了。
而在Attention计算过程中最耗费时间的就是softmax部分(考虑softmax的分母部分),如果能对softmax进行拆解,那么就不需要每次都将QK计算结果再存储回HBM,最后进行归一化时再load一遍。
Flash Attention充分利用了比HBM更小,但速度更快的SRAM执行中间步骤,减少中间结果反复在HBM上存储和读取。因此,Flash Attention算法选择对Q、K、V进行分块,块的大小与SRAM的大小M和Q维度d有关。
算法一开始将Q、K、V矩阵进行了分块,进入循环时,每一次load对应的Qi、Kj、Vj矩阵块,并计算中间结果m和l,m是逐行最大值,l是逐行累加(softmax分母部分),在按块进行计算时,更新到新的块时,会用旧值对新值进行更新。最后,利用对角阵的特性对输出O进行更新,并返回输出结果。
Flash Attention算法并不像传统思路一样,去减少运算次数以降低时间消耗,而是细粒度的分析了一个运算过程中时间消耗,从更底层去优化运算过程。虽然相较于标准Attention,Flash Attention一定程度上增加了计算FLOP,但是通过拆解计算,显著降低了IO次数,从而提升了性能。
3. 块稀疏-Flash Attention
算法可以很容易地扩展到block-sparse FlashAttention,这是一种比Flash Attention更快的稀疏注意力算法,通过使用一个块形式的掩码矩阵,可以跳过嵌套的for循环中的读写。
4. 优化情况
以MLPerf 1.1的记录作为baseline
● BERT-large(序列长度512)训练加速了15%
● GPT-2上(序列长度1K)训练加速了3倍
● long-range arena(序列长1K-4K)训练加速了2.4倍。
Flash Attention因为可以支持更长的序列长度,从而也产生了更高质量的模型。