# FlashAttention @ ICML 2022


talk video: https://www.youtube.com/watch?v=FThvfkXWqtE

TL;DR: accelerate transformers by exploiting the GPU memory hierarchy, without resorting to sparsification.

## Challenges

- Compute the softmax reduction without access to the full input.
- Perform backward propagation without storing the intermediate attention matrix.

## Solution

- Load the inputs block by block, so that attention is computed while the data stays in SRAM.
- But what about `softmax`, which is a reduction over a whole row?
- Don't store the attention matrix; recompute it during backpropagation instead.

### Tiling

The figure below shows how the matrices are tiled and how self-attention is performed on each tile.

However, the `softmax` value is only available after we have traversed the whole row of the attention matrix.

What's the solution? Suppose we tile a vector of size 4 into tiles of size 2; then some simple algebra gives:

\[ \frac{e^{x_1}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}} = \frac{e^{x_1}}{e^{x_1} + e^{x_2}} \cdot \frac{e^{x_1} + e^{x_2}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}} \]

where \(\frac{e^{x_1}}{e^{x_1} + e^{x_2}}\) is the tile-wise `softmax` result. We can store the value of \(e^{x_1} + e^{x_2}\) while computing the \(\textrm{softmax}(QK^T) V\) contribution of the first tile.

After all tiles have been processed, we scale each tile's partial result by the ratio of its local normalizer to the global one.
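The decomposition above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's kernel: the names (`tiled_softmax_matmul`, `partial_norms`) are mine, and numerical stability is ignored here for clarity.

```python
import numpy as np

def tiled_softmax_matmul(scores, V, tile=2):
    """Compute softmax(scores) @ V one column tile at a time."""
    n = scores.shape[1]
    partial_norms = []   # per-tile sums of exponentials (local normalizers)
    partial_outs = []    # per-tile unnormalized outputs
    for start in range(0, n, tile):
        e = np.exp(scores[:, start:start + tile])     # unnormalized weights
        partial_norms.append(e.sum(axis=1, keepdims=True))
        partial_outs.append(e @ V[start:start + tile])
    total = sum(partial_norms)   # global softmax denominator, known only now
    # Scaling each tile's local softmax output by local_norm / total
    # collapses to (unnormalized partial output) / total.
    return sum(partial_outs) / total
```

Note that no tile ever needs the full row: only the scalar per-row normalizers are carried across tiles.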

There are also numerical issues with `softmax` that we need to care about: \(e^x\) overflows quickly, so in practice the row maximum is subtracted before exponentiating.
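A hedged sketch of the standard fix: keep a running maximum, and rescale the accumulated sum whenever the maximum grows (the "online softmax" idea mentioned later in this note). The names here are illustrative, not from the paper.

```python
import numpy as np

def online_softmax(x, tile=2):
    """Numerically stable softmax over a 1-D array, one tile at a time."""
    m = -np.inf   # running maximum seen so far
    l = 0.0       # running sum of exp(x_i - m)
    for start in range(0, len(x), tile):
        chunk = x[start:start + tile]
        m_new = max(m, chunk.max())
        # Rescale the old sum to the new maximum, then add the new tile.
        l = l * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    # Second pass for the final values (a sketch; the fused kernel avoids this).
    return np.exp(x - m) / l
```

Because every exponent is shifted by the running maximum, no `exp` call sees a large positive argument, even for inputs where naive `np.exp(x)` would overflow.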

### Recomputation

Recomputation adds some FLOPs; however, the running time still decreases because of the reduced HBM accesses.
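The recomputation idea can be sketched as follows: the forward pass stores only the per-row softmax statistics (maximum `m` and normalizer `l`), and the backward pass rebuilds the attention matrix from Q, K, and those statistics instead of loading it from memory. This is a conceptual NumPy sketch under my own naming, not the paper's kernel.

```python
import numpy as np

def attention_forward(Q, K):
    """Forward pass; returns the attention matrix and the stats to keep."""
    S = Q @ K.T
    m = S.max(axis=1, keepdims=True)               # row-wise max
    l = np.exp(S - m).sum(axis=1, keepdims=True)   # row-wise normalizer
    P = np.exp(S - m) / l                          # attention matrix (discarded)
    return P, (m, l)

def recompute_attention(Q, K, stats):
    """Backward-pass helper: rebuild P from Q, K and the saved stats."""
    m, l = stats
    return np.exp(Q @ K.T - m) / l                 # same P, no stored S or P

```

Storing `(m, l)` costs O(n) per row instead of the O(n^2) attention matrix, which is the memory saving that pays for the extra matmul.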

### My Two Cents

I enjoyed reading the paper: the idea is intuitive and the solution is concise.

I'm questioning whether it is really necessary to make transformers sparse: even with the help of block-sparsity, sparse transformer variants have nearly the worst memory-access pattern (almost no re-use). The only way to make sparse transformers efficient is to make the cache large; architectures such as Graphcore's might help (I'm not quite sure).

p.s. Junru suggested a paper called Online Softmax to me, which gives a more comprehensive explanation of the softmax decomposition. (Can we propose rules to automatically tile fused kernels? Sounds interesting.)