# FlashAttention @ ICML 2022

TL;DR: accelerating exact attention in transformers by taking the GPU memory hierarchy (HBM vs. on-chip SRAM) into account, without resorting to sparsification.

## Challenges

1. Computing the softmax without access to the whole row of the attention matrix at once.
2. Backpropagation without storing the large intermediate attention matrix.

## Solution

1. Load the inputs from HBM into SRAM block by block, so that the attention computation itself only touches fast on-chip SRAM.
2. Don't store the attention matrix; recompute it during backpropagation instead.
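To put a number on item 2, here is a back-of-the-envelope sketch. The sizes are my own assumptions for illustration: one head, one batch element, fp16 attention scores, and two fp32 statistics per row (a row max and a row sum), which is roughly what is kept around for the backward pass instead of the full matrix.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized N x N attention matrix (one head, one batch item)."""
    return seq_len * seq_len * bytes_per_elem


def row_stats_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for two per-row statistics (row max and row sum) kept instead."""
    return 2 * seq_len * bytes_per_elem


if __name__ == "__main__":
    n = 4096
    print(attention_matrix_bytes(n) / 2**20, "MiB for the N x N matrix (fp16)")  # 32.0
    print(row_stats_bytes(n) / 2**10, "KiB for the row statistics (fp32)")      # 32.0
```

The matrix grows quadratically with the sequence length while the statistics grow linearly, so recomputing from the statistics scales much better than storing.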

### Tiling

The figure below shows how to tile the matrices and perform self-attention on each tiled block.
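Tiling the matrix product itself is easy. A minimal NumPy sketch (the name `tiled_scores` and the block size are illustrative, not from the paper) that computes $$S = QK^T$$ one tile at a time; it still writes every tile into a full `S` only so the result can be checked, whereas a real kernel would consume each tile on the fly:

```python
import numpy as np


def tiled_scores(Q: np.ndarray, K: np.ndarray, block: int) -> np.ndarray:
    """Compute S = Q @ K.T one (block x block) tile at a time.

    Each tile needs only a block of rows of Q and a block of rows of K,
    i.e. what would be staged in SRAM on a GPU.
    """
    n_q, n_k = Q.shape[0], K.shape[0]
    S = np.empty((n_q, n_k), dtype=Q.dtype)
    for i in range(0, n_q, block):
        Q_blk = Q[i:i + block]                      # "load" a block of queries
        for j in range(0, n_k, block):
            K_blk = K[j:j + block]                  # "load" a block of keys
            S[i:i + block, j:j + block] = Q_blk @ K_blk.T
    return S


rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 4))
K = rng.standard_normal((8, 4))
assert np.allclose(tiled_scores(Q, K, block=3), Q @ K.T)   # ragged last tile is fine
```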

However, a softmax value is only available after we have traversed the whole row of the attention matrix.

What's the solution? Suppose we split a vector of size 4 into 2 tiles of size 2; some simple math gives:

$\frac{e^{x_1}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}} = \frac{e^{x_1}}{e^{x_1} + e^{x_2}} \cdot \frac{e^{x_1} + e^{x_2}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}}$

where $$\frac{e^{x_1}}{e^{x_1} + e^{x_2}}$$ is the tile-wise softmax result. We can store the value of $$e^{x_1} + e^{x_2}$$ while computing the part of $$\textrm{softmax}(QK^T) V$$ that corresponds to the first tile.
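A quick numerical check of the identity above, as a sketch with an arbitrary 4-element vector:

```python
import numpy as np

x = np.array([0.5, -1.0, 2.0, 0.3])            # x_1 .. x_4
full = np.exp(x) / np.exp(x).sum()             # softmax over the whole vector

tile1 = x[:2]                                  # first tile: x_1, x_2
tile_sum = np.exp(tile1).sum()                 # e^{x_1} + e^{x_2}, stored for later
tile_softmax = np.exp(tile1) / tile_sum        # tile-wise softmax
correction = tile_sum / np.exp(x).sum()        # (e^{x_1} + e^{x_2}) / (sum over all 4)

assert np.allclose(tile_softmax * correction, full[:2])
```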

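After all tiles are finished, we rescale each tile's partial result by the ratio of that tile's exponential sum to the sum over all tiles (e.g. $$\frac{e^{x_1} + e^{x_2}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}}$$ for the first tile) and add the partial results up.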
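The same rescaling works for a whole row of $$\textrm{softmax}(QK^T) V$$, because the output is linear in the softmax weights. A sketch for a single query row `s` (its scores against all keys), using a hypothetical helper `two_pass_tiled_row`; like the equation above, it skips the max subtraction, so it is only safe for small scores:

```python
import numpy as np


def two_pass_tiled_row(s: np.ndarray, V: np.ndarray, tile: int) -> np.ndarray:
    """softmax(s) @ V for one query row, computed tile by tile.

    Pass 1, per tile: tile-wise softmax times that tile's rows of V, plus the
    tile's exponential sum. Pass 2: rescale each partial result by
    tile_sum / total_sum and add them up. No max subtraction, so small scores only.
    """
    partials, sums = [], []
    for j in range(0, s.shape[0], tile):
        e = np.exp(s[j:j + tile])
        partials.append((e / e.sum()) @ V[j:j + tile])   # tile-wise softmax(s) V
        sums.append(e.sum())                             # kept for the rescaling
    total = sum(sums)
    return sum((l / total) * o for l, o in zip(sums, partials))


rng = np.random.default_rng(1)
s = rng.standard_normal(8)          # scores of one query against 8 keys
V = rng.standard_normal((8, 3))     # 8 value vectors of dimension 3
p = np.exp(s) / np.exp(s).sum()     # untiled softmax, for comparison
assert np.allclose(two_pass_tiled_row(s, V, tile=4), p @ V)
```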

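There are also numerical issues with softmax to take care of: $$e^{x}$$ overflows for large $$x$$, so in practice each row also tracks a running maximum $$m$$, exponentiates $$x_i - m$$ instead of $$x_i$$, and rescales the earlier partial sums by $$e^{m_{\textrm{old}} - m_{\textrm{new}}}$$ whenever the maximum grows.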
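Putting tiling, rescaling and the running maximum together gives a forward pass in the spirit of the paper's algorithm. This is my own NumPy sketch, not the paper's kernel: it keeps an unnormalized accumulator and divides once at the end, omits the usual $$1/\sqrt{d}$$ scaling to match these notes, and the name `flash_forward` is hypothetical.

```python
import numpy as np


def flash_forward(Q, K, V, block=4):
    """softmax(Q K^T) V computed tile by tile with a running max and running sum.

    Returns O and the per-row log-sum-exp L = m + log(l). No 1/sqrt(d) scaling.
    """
    n, d_v = Q.shape[0], V.shape[1]
    O = np.zeros((n, d_v))
    L = np.zeros(n)
    for i in range(0, n, block):
        q = Q[i:i + block]                          # one block of query rows
        rows = q.shape[0]
        m = np.full(rows, -np.inf)                  # running row max
        l = np.zeros(rows)                          # running sum of exp(s - m)
        acc = np.zeros((rows, d_v))                 # running sum of exp(s - m) @ V
        for j in range(0, K.shape[0], block):
            s = q @ K[j:j + block].T                # scores for this tile only
            m_new = np.maximum(m, s.max(axis=1))
            scale = np.exp(m - m_new)               # e^{m_old - m_new}; 0 on the first tile
            p = np.exp(s - m_new[:, None])
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ V[j:j + block]
            m = m_new
        O[i:i + block] = acc / l[:, None]           # normalize once at the end
        L[i:i + block] = m + np.log(l)
    return O, L


def reference(Q, K, V):
    """Naive attention that materializes the full N x N matrix (safe softmax)."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 4)) for _ in range(3))
O, L = flash_forward(Q, K, V, block=4)
assert np.allclose(O, reference(Q, K, V))
```

Scaling `Q` by 100 makes a plain `np.exp(Q @ K.T)` overflow, while the running-max version stays finite.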

### Recomputation

Recomputation adds extra FLOPs; however, the running time still decreases because it reduces HBM accesses.
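Below is a sketch of a backward pass that recomputes the attention weights from the stored per-row log-sum-exp `L` instead of reading back a stored $$N \times N$$ matrix. Assumptions: NumPy, no $$1/\sqrt{d}$$ scaling, tiling over key blocks only, and all function names are hypothetical. The gradient formulas are the standard ones for softmax attention, checked here against finite differences rather than copied from the paper's pseudocode.

```python
import numpy as np


def forward(Q, K, V):
    """Dense forward pass that keeps O and the row log-sum-exp L, but not P."""
    S = Q @ K.T
    m = S.max(axis=1)
    L = m + np.log(np.exp(S - m[:, None]).sum(axis=1))
    O = np.exp(S - L[:, None]) @ V                 # P = exp(S - L) is row-normalized
    return O, L


def backward_recompute(Q, K, V, O, L, dO, block=4):
    """Gradients of sum(O * dO) w.r.t. Q, K, V, recomputing P per key block."""
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    D = (dO * O).sum(axis=1)                       # rowsum(dO * O) == rowsum(P * dP)
    for j in range(0, K.shape[0], block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        P = np.exp(Q @ Kj.T - L[:, None])          # recomputed from L, never stored
        dV[j:j + block] = P.T @ dO
        dP = dO @ Vj.T
        dS = P * (dP - D[:, None])                 # softmax backward, per row
        dQ += dS @ Kj
        dK[j:j + block] = dS.T @ Q
    return dQ, dK, dV


def numgrad(f, X, eps=1e-6):
    """Central finite differences of a scalar function f at X (for checking only)."""
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        G[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return G


rng = np.random.default_rng(2)
Q, K, V, W = (rng.standard_normal((6, 3)) for _ in range(4))
O, L = forward(Q, K, V)
dQ, dK, dV = backward_recompute(Q, K, V, O, L, dO=W)   # upstream gradient of sum(O * W)
assert np.allclose(dV, numgrad(lambda X: (forward(Q, K, X)[0] * W).sum(), V), atol=1e-6)
```

Each `Q @ Kj.T` in the backward loop is part of the extra FLOPs; in exchange, nothing of size $$N \times N$$ has to be read back from memory.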

### My Two Cents

I enjoyed reading the paper: the idea is intuitive and the solution is concise.

I'm asking myself whether it is necessary to make transformers sparse at all: even with the help of block-sparsity, sparse transformer variants have nearly the worst memory access pattern (almost no reuse). The only way to make sparse transformers efficient is a larger cache; architectures such as Graphcore's might help (I'm not quite sure).

P.S. Junru suggested a paper on online softmax to me, which gives a more comprehensive explanation of the softmax decomposition. (Could we come up with rules to automatically tile fused kernels? Sounds interesting.)

Email: expye@outlook.com

Date: 2022-08-11 Thu 00:00