FlashAttention @ ICML 2022

Paper link: https://www.youtube.com/watch?v=FThvfkXWqtE

TL;DR: accelerating Transformers by taking the GPU memory hierarchy (HBM vs. on-chip SRAM) into account, without resorting to sparsification.

Challenges

  1. Computing the softmax reduction without access to the full input.
  2. Running backpropagation without storing the intermediate attention matrix.

Solution

  1. Load Q, K and V block by block from HBM into SRAM, so that the attention computation itself only touches SRAM.
    1. What about softmax?
  2. Don't store the attention matrix; recompute it during backpropagation instead.

Tiling

The figure below shows how to tile the matrices and perform self-attention on each tiled block.

[Figure: tiling the matrices for block-wise self-attention]

However, the softmax value is only available after we have traversed the whole row of the attention matrix, because its denominator sums over every element in that row.

What's the solution? Suppose we split a vector of length 4 into two tiles of size 2. Some simple math gives:

\[ \frac{e^{x_1}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}} = \frac{e^{x_1}}{e^{x_1} + e^{x_2}} \cdot \frac{e^{x_1} + e^{x_2}}{e^{x_1} + e^{x_2} + e^{x_3} + e^{x_4}} \]

where \(\frac{e^{x_1}}{e^{x_1} + e^{x_2}}\) is the tile-wise softmax result. While computing the part of \(\textrm{softmax}(QK^T) V\) that corresponds to the first tile, we store the tile's partial sum \(e^{x_1} + e^{x_2}\).
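A quick numerical sanity check of the identity above, as a minimal Python sketch. The input vector is arbitrary and the helper `softmax` is a hypothetical name defined here, not something from the paper:

```python
import math

def softmax(xs):
    """Plain softmax over a list of floats (no max-subtraction; fine for small inputs)."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

x = [0.5, -1.0, 2.0, 0.25]  # arbitrary example vector of length 4

# Full softmax, computed with access to the whole vector.
full = softmax(x)

# Tile 1 = x[0:2], tile 2 = x[2:4].
tile1_sum = math.exp(x[0]) + math.exp(x[1])  # stored while processing tile 1
tile2_sum = math.exp(x[2]) + math.exp(x[3])  # known only after tile 2 is processed
tile1_softmax = softmax(x[0:2])              # tile-wise softmax of tile 1

# Multiply by the second factor of the identity: tile sum / total sum.
decomposed = [p * tile1_sum / (tile1_sum + tile2_sum) for p in tile1_softmax]

print(full[:2])
print(decomposed)
```

Both lines print the same two probabilities (up to floating-point rounding).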

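After we finish computing all tiles, we rescale each tile's partial result by the factor (tile sum) / (sum over all tiles), i.e. the second term of the product in the equation above, and add the rescaled results together.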
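To make the tile-then-rescale idea concrete for one query row, here is a minimal sketch in plain Python. Assumptions: `scores` is one row of \(QK^T\) (no \(1/\sqrt{d}\) scaling, matching the notation above), `values` holds the rows of V, and `naive_row_attention` / `tiled_row_attention` are hypothetical helper names. It deliberately skips the max-subtraction trick, so it only works for small scores:

```python
import math

def naive_row_attention(scores, values):
    """softmax(scores) @ values for one query row; values is a list of row vectors."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[c] for e, v in zip(exps, values)) / total for c in range(dim)]

def tiled_row_attention(scores, values, tile_size):
    """Same result, but each tile is normalized locally and rescaled at the end."""
    dim = len(values[0])
    partials, tile_sums = [], []
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        exps = [math.exp(s) for s in s_tile]
        tile_sum = sum(exps)
        # Tile-wise softmax(s_tile) @ v_tile, normalized by this tile's sum only.
        partials.append([sum(e * v[c] for e, v in zip(exps, v_tile)) / tile_sum
                         for c in range(dim)])
        tile_sums.append(tile_sum)  # the per-tile statistic we keep
    total = sum(tile_sums)
    # Rescale each tile's partial result by tile_sum / total and accumulate.
    out = [0.0] * dim
    for partial, tile_sum in zip(partials, tile_sums):
        for c in range(dim):
            out[c] += partial[c] * tile_sum / total
    return out

scores = [0.5, -1.0, 2.0, 0.25]                              # one row of Q K^T (example)
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]  # rows of V (example)
print(naive_row_attention(scores, values))
print(tiled_row_attention(scores, values, tile_size=2))
```

This keeps every partial result until the end, which is the most literal reading of the paragraph above; a real kernel would instead fold the rescaling into a single running accumulator.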

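We also need to care about the numerical stability of softmax: \(e^{x}\) overflows for large \(x\), so in practice we subtract the running row maximum before exponentiating, and rescale the previously accumulated results whenever that maximum changes.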
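Because the row maximum is only known after the whole row has been seen, a one-pass version keeps a running maximum `m` and a running sum `l`, and multiplies what it has accumulated so far by `exp(m_old - m_new)` whenever the maximum grows. A minimal sketch in plain Python, assuming one query row, scores taken directly from \(QK^T\) without \(1/\sqrt{d}\) scaling, and hypothetical function names. Unlike the paper's kernel, it defers the final division by `l` to the very end:

```python
import math

def naive_stable_row_attention(scores, values):
    """Reference: softmax with max-subtraction over the full row, then @ values."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[c] for e, v in zip(exps, values)) / total for c in range(dim)]

def online_row_attention(scores, values, tile_size):
    """One pass over tiles, keeping only a running max m, running sum l and accumulator acc."""
    dim = len(values[0])
    m = -math.inf      # running maximum of the scores seen so far
    l = 0.0            # running sum of exp(score - m)
    acc = [0.0] * dim  # running sum of exp(score - m) * value (unnormalized)
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        # Rescale old statistics to the new max; exp(-inf) == 0.0 on the first tile.
        scale = math.exp(m - m_new)
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - m_new)  # exponent <= 0, so no overflow
            l += p
            for c in range(dim):
                acc[c] += p * v[c]
        m = m_new
    return [a / l for a in acc]  # normalize once at the end

scores = [1000.0, 998.0, 1003.0, 999.5]  # large enough that exp(score) overflows
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
print(online_row_attention(scores, values, tile_size=2))
print(naive_stable_row_attention(scores, values))
```

With these scores, `math.exp(1000.0)` on its own raises `OverflowError` in Python, whereas the online version only ever exponentiates non-positive numbers.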

Recomputation

Recomputation adds some FLOPs, but the running time still decreases because far fewer HBM accesses are needed.
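Recomputation is cheap to set up because the forward pass only needs to save a tiny amount of softmax state per row. A minimal sketch of the idea in plain Python, assuming a single head, no \(1/\sqrt{d}\) scaling, and a per-row log-sum-exp `L` as the saved statistic (the paper's exact bookkeeping may differ; `forward` and `recompute_probs` are hypothetical names). The forward pass is written row by row rather than tiled, to keep the focus on what gets saved: O (N×d) and L (length N) instead of the N×N probability matrix:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def forward(Q, K, V):
    """Returns the output O and per-row log-sum-exp L; the N x N matrix P is not kept."""
    O, L = [], []
    for q in Q:
        scores = [dot(q, k) for k in K]
        m = max(scores)
        lse = m + math.log(sum(math.exp(s - m) for s in scores))
        probs = [math.exp(s - lse) for s in scores]  # discarded after this row
        O.append([sum(p * v[c] for p, v in zip(probs, V)) for c in range(len(V[0]))])
        L.append(lse)
    return O, L

def recompute_probs(Q, K, L):
    """Backward-pass helper: rebuild P = softmax(Q K^T) from Q, K and the saved L."""
    return [[math.exp(dot(q, k) - lse) for k in K] for q, lse in zip(Q, L)]

Q = [[0.1, 0.2], [0.3, -0.1], [-0.5, 0.4]]
K = [[0.2, 0.1], [-0.3, 0.5], [0.4, 0.4]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

O, L = forward(Q, K, V)
P = recompute_probs(Q, K, L)  # what the backward pass would use together with dO
O_check = [[sum(p * v[c] for p, v in zip(row, V)) for c in range(len(V[0]))] for row in P]
print(O)
print(O_check)
```

In a real kernel P would be rebuilt block by block inside SRAM rather than materialized as a whole list, as done here for brevity.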

[Figure: effect of recomputation on FLOPs, HBM access and runtime]

My Two Cents

I enjoyed reading the paper: the idea is intuitive and the solution is concise.

I'm now questioning whether it is necessary to make Transformers sparse at all: even with block sparsity, sparse Transformer variants have nearly the worst possible memory access pattern (almost no data reuse). The only way I see to make sparse Transformers efficient is a large on-chip cache; architectures such as Graphcore's might help (I'm not quite sure).

P.S. Junru suggested a paper on Online Softmax to me, which gives a more thorough explanation of the softmax decomposition. (Could we propose rules to automatically tile fused kernels? Sounds interesting.)

Author: expye(Zihao Ye)

Email: expye@outlook.com

Date: 2022-08-11 Thu 00:00

Last modified: 2022-09-23 Fri 00:04

Licensed under CC BY-NC 4.0