# FlashAttention @ ICML 2022

TL;DR: accelerating transformers by taking the GPU memory hierarchy into account, without resorting to sparsification.

## Challenges

1. Computing the softmax reduction block by block, without access to the whole row of the attention matrix.
2. Backpropagating without storing the intermediate attention matrix.

## Solution

1. Load the inputs block by block, so that attention is computed entirely in SRAM.
2. Don't save the attention matrix; recompute it during backpropagation.

### Tiling

The figure below shows how to tile the matrices and perform self-attention on each tiled block.

However, the softmax denominator is only available after we traverse the whole row of the attention matrix.

What's the solution? Suppose we tile a vector of size 4 into tiles of size 2 and use some simple math:

$\frac{e^{x_1}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}} = \frac{e^{x_1}}{e^{x_1} + e^{x_2}} \cdot \frac{e^{x_1} + e^{x_2}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}}$

where $\frac{e^{x_1}}{e^{x_1} + e^{x_2}}$ is the tile-wise softmax result. We can store the value of $e^{x_1} + e^{x_2}$ while computing the part of $\textrm{softmax}(QK^T) V$ that corresponds to the first tile.

After finishing the computations of all tiles, we scale each tile's partial result by the ratio of its local denominator to the global one (the second factor above).
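To make this concrete, here is a minimal numpy sketch of the accumulation (the function name and tiling loop are my own; the paper implements this as a fused CUDA kernel, and this version ignores overflow, which is addressed next):

```python
import numpy as np

def tiled_attention_unsafe(Q, K, V, tile=2):
    """softmax(Q K^T) V computed tile by tile with a running denominator.
    Mirrors the decomposition above; ignores overflow (see next paragraph)."""
    N, d = Q.shape
    O = np.zeros((N, d))   # accumulator for the unnormalized output
    l = np.zeros((N, 1))   # running softmax denominator per row
    for j in range(0, N, tile):            # walk the row tile by tile
        S = Q @ K[j:j + tile].T            # scores for this tile, shape (N, tile)
        P = np.exp(S)                      # tile-wise exponentials
        l += P.sum(axis=1, keepdims=True)  # accumulate e^{x_1} + e^{x_2} + ...
        O += P @ V[j:j + tile]             # partial softmax(QK^T)V for this tile
    return O / l                           # final rescaling by the full denominator

# Agrees with the untiled reference:
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 3)) for _ in range(3))
S = Q @ K.T
ref = (np.exp(S) / np.exp(S).sum(axis=1, keepdims=True)) @ V
assert np.allclose(tiled_attention_unsafe(Q, K, V), ref)
```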

There are also numerical issues with softmax to take care of: $e^x$ overflows for moderately large $x$, so every exponent is shifted by the running row maximum before exponentiating (the safe-softmax trick).
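Here is a sketch of the resulting update rule, in the same numpy style (the function name and the $(O, l, m)$ state layout are my own; the paper fuses this into the kernel). Each incoming tile updates the running row max, and everything accumulated so far is rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$:

```python
import numpy as np

def merge_tile(O, l, m, S_tile, V_tile):
    """Fold one tile of attention scores into the running state.
    O: (N, d) unnormalized output, l: (N, 1) denominator, m: (N, 1) row max."""
    m_new = np.maximum(m, S_tile.max(axis=1, keepdims=True))  # updated row max
    P = np.exp(S_tile - m_new)   # shifted exponentials can no longer overflow
    alpha = np.exp(m - m_new)    # correction factor for earlier tiles
    l_new = alpha * l + P.sum(axis=1, keepdims=True)
    O_new = alpha * O + P @ V_tile
    return O_new, l_new, m_new
```

Starting from $O = 0$, $l = 0$, $m = -\infty$ and returning $O / l$ after the last tile reproduces the exact (safe) softmax.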

### Recomputation

Recomputation adds some FLOPs; the running time nevertheless decreases because of the reduced HBM accesses.
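As I understand it, the forward pass keeps only the per-row statistics $m$ and $l$ ($O(N)$ memory instead of $O(N^2)$), and the backward pass rebuilds each tile of $P = \textrm{softmax}(QK^T)$ from $Q$, $K$, and those statistics. A sketch of that rebuilding step (my naming), continuing the numpy style above:

```python
import numpy as np

def rebuild_probs_tile(Q, K, m, l, j, tile):
    """Recompute one tile of P = softmax(Q K^T) during the backward pass
    from Q, K and the saved per-row max m (N, 1) and denominator l (N, 1),
    instead of reading a stored N x N attention matrix from HBM."""
    S_tile = Q @ K[j:j + tile].T   # recomputed scores: the extra FLOPs
    return np.exp(S_tile - m) / l  # exactly the forward pass's tile of P
```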

### My Two Cents

I enjoyed reading the paper: the idea is intuitive and the solution is concise.

I keep asking myself whether it is necessary to make transformers sparse: even with the help of block sparsity, sparse transformer variants have nearly the worst memory access pattern (almost no reuse). The only way to make sparse transformers efficient may be to enlarge the cache; architectures such as Graphcore's might help (I'm not quite sure).

p.s. Junru suggested a paper called Online Softmax to me, which gives a more comprehensive explanation of the softmax decomposition. (Can we propose some rules to automatically tile fused kernels? Sounds interesting.)

Email: expye@outlook.com

Date: 2022-08-11 Thu 00:00