FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness
Abstract
Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware, that is, accounting for reads and writes between levels of GPU memory. We propose FLASHATTENTION, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory and GPU on-chip SRAM. We analyze the IO complexity of FLASHATTENTION, showing that it requires fewer high bandwidth memory accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FLASHATTENTION to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FLASHATTENTION trains Transformers faster than existing baselines: fifteen percent end-to-end wall-clock speedup on BERT-large (sequence length five hundred twelve) compared to the MLPerf one point one training speed record, three times speedup on GPT-two (sequence length one thousand), and two point four times speedup on long-range arena (sequence length one thousand to four thousand). FLASHATTENTION and block-sparse FLASHATTENTION enable longer context in Transformers, yielding higher quality models (zero point seven better perplexity on GPT-two and six point four points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (sequence length sixteen thousand, sixty-one point four percent accuracy) and Path-two hundred fifty-six (sequence length sixty-four thousand, sixty-three point one percent accuracy).
One Introduction
Transformer models have emerged as the most widely used architecture in applications such as natural language processing and image classification. Transformers have grown larger and deeper, but equipping them with longer context remains difficult, since the self-attention module at their heart has time and memory complexity quadratic in sequence length. An important question is whether making attention faster and more memory-efficient can help Transformer models address their runtime and memory challenges for long sequences.
Many approximate attention methods have aimed to reduce the compute and memory requirements of attention. These methods range from sparse-approximation to low-rank approximation, and their combinations. Although these methods reduce the compute requirements to linear or near-linear in sequence length, many of them do not display wall-clock speedup against standard attention and have not gained wide adoption. One main reason is that they focus on FLOP reduction (which may not correlate with wall-clock speed) and tend to ignore overheads from memory access (IO).
In this paper, we argue that a missing principle is making attention algorithms IO-aware, that is, carefully accounting for reads and writes to different levels of fast and slow memory (e.g., between fast GPU on-chip SRAM and relatively slow GPU high bandwidth memory; see Figure one, left). On modern GPUs, compute speed has out-paced memory speed, and most operations in Transformers are bottlenecked by memory accesses. IO-aware algorithms have been critical for similar memory-bound operations, where reading and writing data can account for a large portion of the runtime, such as database joins, image processing, numerical linear algebra, and more. However, common Python interfaces to deep learning such as PyTorch and TensorFlow do not allow fine-grained control of memory access.
We propose FLASHATTENTION, a new attention algorithm that computes exact attention with far fewer memory accesses. Our main goal is to avoid reading and writing the attention matrix to and from high bandwidth memory. This requires (i) computing the softmax reduction without access to the whole input and (ii) not storing the large intermediate attention matrix for the backward pass. We apply two well-established techniques to address these challenges. (i) We restructure the attention computation to split the input into blocks and make several passes over input blocks, thus incrementally performing the softmax reduction (also known as tiling). (ii) We store the softmax normalization factor from the forward pass to quickly recompute attention on-chip in the backward pass, which is faster than the standard approach of reading the intermediate attention matrix from high bandwidth memory. We implement FLASHATTENTION in CUDA to achieve fine-grained control over memory access and fuse all the attention operations into one GPU kernel. Even with the increased FLOPs due to recomputation, our algorithm both runs faster (up to seven point six times on GPT-two, Figure one right) and uses less memory (linear in sequence length) than standard attention, thanks to the massively reduced amount of high bandwidth memory access.
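The two techniques can be illustrated with a minimal pure-Python sketch. It is not the CUDA kernel described here; it only mimics the arithmetic on nested lists. The function names (`standard_attention`, `tiled_attention`, `recompute_prob`), the `block_size` parameter, and the omission of any softmax scaling factor are assumptions made for illustration. The sketch processes keys and values one block at a time while carrying a running maximum and a running normalizer per query row, then reuses the stored log-normalizer to recompute a single attention probability without the full attention matrix.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def standard_attention(Q, K, V):
    """Reference: builds each full row of softmax(Q K^T) before using it."""
    out = []
    for q in Q:
        scores = [dot(q, k) for k in K]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out.append([sum(e * v[j] for e, v in zip(exps, V)) / total
                    for j in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_size=2):
    """Visits K and V in blocks, keeping only per-row running statistics.

    Returns the outputs and, per query row, the log of the softmax
    normalizer (the quantity kept for later recomputation).
    """
    d_v = len(V[0])
    out, logsumexp = [], []
    for q in Q:
        m = float("-inf")   # running max of the scores seen so far
        l = 0.0             # running sum of exp(score - m)
        acc = [0.0] * d_v   # running unnormalized output
        for start in range(0, len(K), block_size):
            K_blk = K[start:start + block_size]
            V_blk = V[start:start + block_size]
            scores = [dot(q, k) for k in K_blk]
            m_new = max(m, max(scores))
            scale = math.exp(m - m_new)  # rescales old statistics; 0.0 on the first block
            exps = [math.exp(s - m_new) for s in scores]
            l = l * scale + sum(exps)
            acc = [a * scale + sum(e * v[j] for e, v in zip(exps, V_blk))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
        logsumexp.append(m + math.log(l))
    return out, logsumexp

def recompute_prob(q, k, lse):
    """Recomputes one softmax entry from the stored log-normalizer alone."""
    return math.exp(dot(q, k) - lse)

Q = [[0.1, 0.2], [0.3, -0.4], [1.0, 0.5]]
K = [[0.5, 0.1], [-0.2, 0.7], [0.9, 0.3], [0.0, -1.0], [0.4, 0.4]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
tiled_out, lse = tiled_attention(Q, K, V, block_size=2)
```

In this sketch the blockwise result agrees with the reference up to floating-point rounding, which is the sense in which tiling keeps attention exact. Only one number per query row needs to be stored for the recomputation step.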
We analyze the IO complexity of FLASHATTENTION, proving that it requires $O(N^2 d^2 M^{-1})$ high bandwidth memory accesses, where $N$ is the sequence length, $d$ is the head dimension, and $M$ is the size of SRAM, as compared to $\Omega(Nd + N^2)$ for standard attention. For typical values of d and M, FLASHATTENTION requires many times fewer high bandwidth memory accesses compared to standard attention (up to nine times fewer, as shown in Fig. two). Moreover, we provide a lower bound, showing that no exact attention algorithm can asymptotically improve on the number of high bandwidth memory accesses over all SRAM sizes.
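As a rough worked example of these two bounds, N^2 d^2 / M for FLASHATTENTION and N d + N^2 for standard attention, the sketch below evaluates both expressions with constant factors dropped. The specific values (sequence length 1024, head dimension 64, and an SRAM budget of 100 KB holding 4-byte elements) are illustrative assumptions, not figures from this section. The function names are hypothetical as well.

```python
def standard_hbm_accesses(N, d):
    """Order of growth N*d + N^2, constant factors dropped."""
    return N * d + N * N

def flash_hbm_accesses(N, d, M):
    """Order of growth N^2 * d^2 / M, constant factors dropped."""
    return N * N * d * d / M

# Assumed, illustrative values (not taken from the text).
N = 1024                # sequence length
d = 64                  # head dimension
M = 100 * 1024 // 4     # SRAM size in elements: 100 KB of 4-byte values

ratio = standard_hbm_accesses(N, d) / flash_hbm_accesses(N, d, M)
print(f"standard / tiled access ratio (constants dropped): {ratio:.2f}")
```

Because the asymptotic notation hides constants, the printed ratio is only indicative. The useful observation is structural: the tiled count shrinks as M grows, and it falls below the N^2 term of the standard count roughly when d^2 is smaller than M.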
We also show that FLASHATTENTION can serve as a useful primitive for realizing the potential of approximate attention algorithms by overcoming their issues with memory access overhead. As a proof of concept, we implement block-sparse FLASHATTENTION, a sparse attention algorithm that is two to four times faster than even FLASHATTENTION, scaling up to sequence length of sixty-four thousand. We prove that block-sparse FLASHATTENTION has better IO complexity than FLASHATTENTION by a factor proportional to the sparsity ratio. We discuss further extensions to other operations (attention on multi-GPU, kernel regression, block-sparse matrix multiply) in Section five. We open-source FLASHATTENTION to make it easier to build on this primitive.
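The block-sparse idea can be sketched in the same blockwise style: key and value blocks whose mask entry is zero are skipped entirely, so the work and the memory traffic scale with the fraction of nonzero blocks. This is a minimal pure-Python illustration, not the released implementation. The function name `block_sparse_attention`, the `block_mask` layout (one boolean per query block and key block), and the assumption that every query block attends to at least one key block are all hypothetical choices made for the example.

```python
import math

def block_sparse_attention(Q, K, V, block_mask, block_size):
    """Blockwise softmax attention that skips masked-out key/value blocks.

    block_mask[i][j] is True when query block i attends to key block j.
    Assumes each row of block_mask has at least one True entry.
    """
    d_v = len(V[0])
    out = []
    for qi, q in enumerate(Q):
        mask_row = block_mask[qi // block_size]
        m, l, acc = float("-inf"), 0.0, [0.0] * d_v
        for bj, start in enumerate(range(0, len(K), block_size)):
            if not mask_row[bj]:
                continue  # skipped block: never loaded, never computed
            K_blk = K[start:start + block_size]
            V_blk = V[start:start + block_size]
            scores = [sum(a * b for a, b in zip(q, k)) for k in K_blk]
            m_new = max(m, max(scores))
            scale = math.exp(m - m_new)  # 0.0 on the first visited block
            exps = [math.exp(s - m_new) for s in scores]
            l = l * scale + sum(exps)
            acc = [a * scale + sum(e * v[j] for e, v in zip(exps, V_blk))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out

def nonzero_fraction(block_mask):
    """Fraction of blocks that are actually visited."""
    flat = [entry for row in block_mask for entry in row]
    return sum(flat) / len(flat)

X = [[0.2, 0.1], [0.4, -0.3], [0.9, 0.5], [-0.6, 0.8]]
diag_mask = [[True, False], [False, True]]   # block-diagonal pattern
sparse_out = block_sparse_attention(X, X, X, diag_mask, block_size=2)
print(nonzero_fraction(diag_mask))
```

With the block-diagonal mask above, half of the blocks are visited, and each query row only sees keys in its own block. An all-True mask reduces the sketch to ordinary dense blockwise attention.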
We empirically validate that FLASHATTENTION speeds up model training and improves model quality by modeling longer context. We also benchmark the runtime and memory footprint of FLASHATTENTION and block-sparse FLASHATTENTION compared to prior attention implementations.
Faster Model Training. FLASHATTENTION trains Transformer models faster in wall-clock time. We train BERT-large (sequence length five hundred twelve) fifteen percent faster than the training speed record in MLPerf one point one, GPT-two (sequence length one thousand) three times faster than baseline implementations from HuggingFace and Megatron-LM, and long-range arena (sequence length one thousand to four thousand) two point four times faster than baselines.
Higher Quality Models. FLASHATTENTION scales Transformers to longer sequences, which improves their quality and enables new capabilities. We observe a zero point seven improvement in perplexity on GPT-two and six point four points of lift from modeling longer sequences on long-document classification. FLASHATTENTION enables the first Transformer that can achieve better-than-chance performance on the Path-X challenge, solely from using a longer sequence length (sixteen thousand). Block-sparse FLASHATTENTION enables a Transformer to scale to even longer sequences (sixty-four thousand), resulting in the first model that can achieve better-than-chance performance on Path-two hundred fifty-six.
Benchmarking Attention. FLASHATTENTION is up to three times faster than the standard attention implementation across common sequence lengths from one hundred twenty-eight to two thousand and scales up to sixty-four thousand. Up to sequence length of five hundred twelve, FLASHATTENTION is both faster and more memory-efficient than any existing attention method, whereas for sequence length beyond one thousand, some approximate attention methods start to become faster. On the other hand, block-sparse FLASHATTENTION is faster than all existing approximate attention methods that we know of.
Two Background
We provide some background on the performance characteristics of common deep learning operations on modern hardware (GPUs). We also describe the standard implementation of attention.