<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Design | Shaoyang Cui</title><link>https://spidermonk7.github.io/tags/design/</link><atom:link href="https://spidermonk7.github.io/tags/design/index.xml" rel="self" type="application/rss+xml"/><description>Design</description><generator>Hugo Blox Builder (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Thu, 19 Mar 2026 00:00:00 +0000</lastBuildDate><image><url>https://spidermonk7.github.io/media/icon_hu7729264130191091259.png</url><title>Design</title><link>https://spidermonk7.github.io/tags/design/</link></image><item><title>Learning Notes - the Flash-Attention</title><link>https://spidermonk7.github.io/post/understanding_fa/</link><pubDate>Thu, 19 Mar 2026 00:00:00 +0000</pubDate><guid>https://spidermonk7.github.io/post/understanding_fa/</guid><description>&lt;h1 id="motivation">Motivation&lt;/h1>
&lt;p>Everything started with a phone call. My girlfriend asked me:&lt;/p>
&lt;p>&amp;ldquo;Honey, you&amp;rsquo;ve mentioned that you successfully installed &lt;em>FlashAttention&lt;/em> on your new cluster. What is that?&amp;rdquo;&lt;/p>
&lt;p>&amp;ldquo;Oh, it&amp;rsquo;s a Python library that helps with LLM inference, and maybe training too. Without &lt;em>FlashAttention&lt;/em>, LLMs can easily run into CUDA OOM during inference.&amp;rdquo;&lt;/p>
&lt;p>I answered without really thinking.&lt;/p>
&lt;p>&amp;ldquo;Cool. How does it work?&amp;rdquo;&lt;/p>
&lt;p>&amp;ldquo;&amp;hellip;&amp;rdquo;&lt;/p>
&lt;p>And then I realized: I actually had no idea how &lt;em>FlashAttention&lt;/em> works, or why it reduces the memory complexity of LLM inference.&lt;/p>
&lt;p>That was the moment I decided to really look into &lt;em>FlashAttention&lt;/em>, build my own blog page, and start recording the important technical ideas I learn along the way.&lt;/p>
&lt;p>It was also a reminder: always be someone who keeps asking why.&lt;/p>
&lt;h1 id="background">Background&lt;/h1>
&lt;p>After some searching, I found that &lt;em>FlashAttention&lt;/em> reduces the memory complexity of LLM inference mainly by optimizing how the softmax in attention is computed.&lt;/p>
&lt;p>In LLMs, softmax appears in more than one place, but the one that matters most for understanding &lt;em>FlashAttention&lt;/em> is the softmax inside the attention mechanism.&lt;/p>
&lt;p>At a high level, self-attention works like this: for each token, the model computes how much attention it should pay to every previous token. Those raw attention scores are first computed, and then softmax turns them into a proper probability distribution.&lt;/p>
&lt;p>More concretely, given query, key, and value matrices $Q$, $K$, and $V$, the attention scores are:&lt;/p>
$$
S = \frac{QK^\top}{\sqrt{d_k}} + M
$$&lt;p>where $M$ is an additive mask with $M_{ij} = -\infty$ for positions that should be masked and $0$ elsewhere.&lt;/p>
&lt;p>Then softmax is applied &lt;strong>row by row&lt;/strong> over the key dimension:&lt;/p>
$$
P_{ij} = \frac{\exp(S_{ij})}{\sum_{t} \exp(S_{it})}
$$&lt;p>This step converts the raw scores into normalized attention weights. Each row now sums to 1, which means the model can interpret them as &amp;ldquo;how much this token attends to each previous token.&amp;rdquo;&lt;/p>
&lt;p>Finally, these attention weights are used to combine the value vectors:&lt;/p>
$$
O = PV
$$&lt;p>So the full attention pipeline is:&lt;/p>
$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$&lt;h2 id="what-softmax-is-doing-here">What softmax is doing here&lt;/h2>
&lt;p>The role of softmax is simple but important:&lt;/p>
&lt;ul>
&lt;li>It turns arbitrary similarity scores into non-negative &lt;strong>normalized&lt;/strong> weights.&lt;/li>
&lt;li>It amplifies larger scores and suppresses smaller ones.&lt;/li>
&lt;li>It makes the model focus on the most relevant context tokens when forming the output.&lt;/li>
&lt;/ul>
&lt;p>Without softmax, the attention scores would just be raw numbers. The model would have no clean way to interpret them as relative importance.&lt;/p>
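&lt;p>As a tiny illustration, here is a numerically stable softmax for one row of scores in pure Python (an illustrative sketch of the formula above, not the implementation used in any library):&lt;/p>

```python
import math

def softmax(scores):
    """Row-wise softmax; subtract the max so exp() cannot overflow."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

w = softmax([2.0, 1.0, 0.1])
# w is non-negative, sums to 1 (up to float rounding),
# and the largest score gets the largest weight.
```

&lt;p>Subtracting the row maximum leaves the result unchanged mathematically, and the same trick reappears later in the blockwise computation.&lt;/p>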
&lt;h2 id="why-this-matters-for-flashattention">Why this matters for &lt;em>FlashAttention&lt;/em>&lt;/h2>
&lt;p>This is exactly the expensive part in standard attention.&lt;/p>
&lt;p>To compute attention in the usual way, we often materialize the full score matrix $QK^\top$ and then the full softmax result. If the sequence length is $L$, this matrix has size $L \times L$, which leads to very high memory cost.&lt;/p>
&lt;p>&lt;em>FlashAttention&lt;/em> does &lt;strong>not&lt;/strong> change the mathematical definition of softmax. Instead, it computes the same result more carefully, in a tiled / blockwise way, so that it avoids storing the full attention matrix in GPU memory.&lt;/p>
&lt;p>So when people say &lt;em>FlashAttention&lt;/em> reduces memory usage, they do &lt;strong>not&lt;/strong> mean it removes softmax. They mean it computes the same softmax-based attention more efficiently.&lt;/p>
&lt;p>At first glance, a single attention head does not seem too expensive.&lt;/p>
&lt;p>For example, if $L=512$ and each number takes 2 bytes, then one $L \times L$ score matrix costs:&lt;/p>
$$
512 \times 512 \times 2 = 524288 \text{ bytes} \approx 0.5 \text{ MB}
$$&lt;p>However, in a real transformer we usually have multiple heads, multiple layers, and additional tensors such as attention probabilities, Q/K/V activations, and KV cache.&lt;/p>
&lt;p>If we include the head dimension, the score matrix memory for one layer is roughly:&lt;/p>
$$
B \times H \times L^2 \times \text{bytes per element}
$$&lt;p>For example, with $B=1$, $H=32$, $L=4096$, and 2 bytes per element, the attention score matrix alone is already about 1 GB for a single layer.&lt;/p>
&lt;p>And this is only one intermediate tensor. That is why attention becomes a serious memory bottleneck for long sequences.&lt;/p>
&lt;p>&lt;strong>Why do we have to store such a large tensor?&lt;/strong>&lt;/p>
$$
P_{ij} = \frac{\exp(S_{ij})}{\sum_{t} \exp(S_{it})}
$$&lt;p>In a naive implementation, we usually store both $S$ and the normalized matrix $P$.
&lt;strong>Why? Because softmax is not just a pointwise operation. For each row, it needs to look at all the elements in that row, compute their relative scale, and normalize them.&lt;/strong>
So the implementation often materializes the whole matrix first, and only then moves on to the next step.&lt;/p>
&lt;p>After that, we compute:&lt;/p>
$$
O = PV
$$&lt;p>So in the standard pipeline, the model often stores large intermediate tensors such as:&lt;/p>
&lt;ul>
&lt;li>the score matrix $S$&lt;/li>
&lt;li>the softmax result $P$&lt;/li>
&lt;li>and sometimes extra temporary values used for numerical stability&lt;/li>
&lt;/ul>
&lt;p>This is why attention becomes so memory-hungry.
The real issue is not just the formula itself, but the way the computation is scheduled and stored in memory.&lt;/p>
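&lt;p>The materialize-first pipeline can be sketched in pure Python (a toy illustration with tiny shapes, not how real kernels are written): first build the full $S$, then the full $P$, and only then compute $O = PV$.&lt;/p>

```python
import math

def naive_attention(Q, K, V):
    """Materialize-first attention: builds full S and P before O = P V."""
    d_k = len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    # Full score matrix S -- the tensor that grows with the square
    # of the sequence length.
    S = [[scale * sum(q[x] * k[x] for x in range(d_k)) for k in K] for q in Q]
    # Full probability matrix P (row-wise stable softmax).
    P = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # O = P V
    d_v = len(V[0])
    O = [[sum(P[i][j] * V[j][x] for j in range(len(V))) for x in range(d_v)]
         for i in range(len(Q))]
    return S, P, O

S, P, O = naive_attention([[1.0, 0.0], [0.0, 1.0]],
                          [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                          [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
```

&lt;p>Both $S$ and $P$ live in memory at the same time here, which is exactly the cost the rest of this post is about.&lt;/p>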
&lt;p>And this is exactly where &lt;em>FlashAttention&lt;/em> comes in.&lt;/p>
&lt;p>The key idea of &lt;em>FlashAttention&lt;/em> is: maybe we do not need to store the full $L \times L$ matrix at all.
Instead of materializing the whole attention matrix in GPU memory, &lt;em>FlashAttention&lt;/em> computes attention block by block, while still producing the exact same final result.&lt;/p>
&lt;p>In other words, &lt;em>FlashAttention&lt;/em> does not change the mathematics of attention.
It changes the order of computation, so that the GPU does much less memory movement and avoids storing those massive intermediate tensors.&lt;/p>
&lt;!--
# Flash Attention
The natural question is: if each softmax entry depends on the whole row, how can we avoid storing the full $L \times L$ matrix?
The answer is that we do **not** need to store the whole row forever. We only need enough information to keep softmax numerically correct while we process the row block by block.
## Step 1: Rewrite attention for one row
Let us focus on a single query row $i$. Denote its query vector by $q_i$. For every key $k_j$, we compute a score:
$$
s_{ij} = \frac{q_i k_j^\top}{\sqrt{d_k}}
$$
Then the attention probability is:
$$
p_{ij} = \frac{\exp(s_{ij})}{\sum_{t=1}^{L}\exp(s_{it})}
$$
And the final output vector for this row is:
$$
o_i = \sum_{j=1}^{L} p_{ij} v_j
$$
Substituting the definition of $p_{ij}$ gives:
$$
o_i = \frac{\sum_{j=1}^{L}\exp(s_{ij})v_j}{\sum_{j=1}^{L}\exp(s_{ij})}
$$
This form is already very useful. It tells us that to compute $o_i$, we do not really need the whole softmax matrix itself. What we need is:
- the denominator $\sum_j \exp(s_{ij})$
- the weighted numerator $\sum_j \exp(s_{ij}) v_j$
For numerical stability, we should not use $\exp(s_{ij})$ directly. We first subtract the row maximum:
$$
m_i = \max_j s_{ij}
$$
Then we rewrite the denominator as:
$$
\ell_i = \sum_{j=1}^{L}\exp(s_{ij} - m_i)
$$
and the numerator as:
$$
u_i = \sum_{j=1}^{L}\exp(s_{ij} - m_i)v_j
$$
So the row output becomes:
$$
o_i = \frac{u_i}{\ell_i}
$$
This is the key observation: for one row, we only need to keep track of three things:
- the running max $m_i$
- the running normalizer $\ell_i$
- the running weighted sum $u_i$
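As a concrete sketch, here is a pure-Python loop (illustrative only, not the actual GPU kernel) that computes one output row while keeping only these three running statistics:

```python
import math

def streaming_row(scores, values):
    """One attention row, storing only the running (m, l, u)."""
    m = float("-inf")            # running max
    l = 0.0                      # running normalizer
    u = [0.0] * len(values[0])   # running weighted sum (one vector)
    for s, v in zip(scores, values):
        m_new = max(m, s)
        a = math.exp(m - m_new)   # rescale everything accumulated so far
        w = math.exp(s - m_new)   # weight of the new score
        l = a * l + w
        u = [a * ux + w * vx for ux, vx in zip(u, v)]
        m = m_new
    return [ux / l for ux in u]

o = streaming_row([0.5, 2.0, -1.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
```

This produces exactly the same row output as computing the full softmax first, but it never stores a whole row of probabilities at once.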
## Step 2: Split keys and values into blocks
Now suppose we split the keys and values into blocks:
$$
K = [K^{(1)}, K^{(2)}, \dots, K^{(T)}], \qquad
V = [V^{(1)}, V^{(2)}, \dots, V^{(T)}]
$$
For the same query row $q_i$, block $t$ produces a local score vector:
$$
s_i^{(t)} = \frac{q_i (K^{(t)})^\top}{\sqrt{d_k}}
$$
Inside this block, define:
$$
m_i^{(t)} = \max_{j \in \text{block } t} s_{ij}
$$
$$
\ell_i^{(t)} = \sum_{j \in \text{block } t}\exp(s_{ij} - m_i^{(t)})
$$
$$
u_i^{(t)} = \sum_{j \in \text{block } t}\exp(s_{ij} - m_i^{(t)})v_j
$$
If we only looked at block $t$ in isolation, then its local output would be:
$$
o_i^{(t)} = \frac{u_i^{(t)}}{\ell_i^{(t)}}
$$
The problem is that local softmax values are not the final global softmax values, because the true denominator should include **all** blocks. This is exactly why we need an online update rule.
## Step 3: Online softmax
Suppose we have already processed blocks $1, \dots, t-1$, and we keep three running statistics:
$$
m_i^{\text{old}}, \qquad \ell_i^{\text{old}}, \qquad u_i^{\text{old}}
$$
Now we read the next block $t$, which gives:
$$
m_i^{(t)}, \qquad \ell_i^{(t)}, \qquad u_i^{(t)}
$$
The new global max after merging old information with the new block is:
$$
m_i^{\text{new}} = \max\left(m_i^{\text{old}}, m_i^{(t)}\right)
$$
Now look at the new denominator over all scores seen so far:
$$
\ell_i^{\text{new}} = \sum_{\text{old}} \exp(s_{ij} - m_i^{\text{new}})
+ \sum_{\text{block } t} \exp(s_{ij} - m_i^{\text{new}})
$$
For the old blocks, we factor out the change of reference point from $m_i^{\text{old}}$ to $m_i^{\text{new}}$:
$$
\sum_{\text{old}} \exp(s_{ij} - m_i^{\text{new}})
= \exp(m_i^{\text{old}} - m_i^{\text{new}})
\sum_{\text{old}} \exp(s_{ij} - m_i^{\text{old}})
= \exp(m_i^{\text{old}} - m_i^{\text{new}})\ell_i^{\text{old}}
$$
For the new block:
$$
\sum_{\text{block } t} \exp(s_{ij} - m_i^{\text{new}})
= \exp(m_i^{(t)} - m_i^{\text{new}})
\sum_{\text{block } t} \exp(s_{ij} - m_i^{(t)})
= \exp(m_i^{(t)} - m_i^{\text{new}})\ell_i^{(t)}
$$
Putting them together:
$$
\ell_i^{\text{new}}
= \exp(m_i^{\text{old}} - m_i^{\text{new}})\ell_i^{\text{old}}
+ \exp(m_i^{(t)} - m_i^{\text{new}})\ell_i^{(t)}
$$
The numerator is updated in exactly the same way:
$$
u_i^{\text{new}}
= \exp(m_i^{\text{old}} - m_i^{\text{new}})u_i^{\text{old}}
+ \exp(m_i^{(t)} - m_i^{\text{new}})u_i^{(t)}
$$
and therefore:
$$
o_i^{\text{new}} = \frac{u_i^{\text{new}}}{\ell_i^{\text{new}}}
$$
This is the central derivation behind *FlashAttention*. We never need the full row of $P$, and we never need the full matrix $S$ in GPU memory. We only need the current block and the running statistics $(m_i, \ell_i, u_i)$ for each row.
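The update rule above can be sketched in a few lines of pure Python (names like `block_stats` and `merge` are illustrative, not the library's API):

```python
import math

def block_stats(scores, values):
    """Local (m, l, u) for one block of scores and value vectors."""
    dim = len(values[0])
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    u = [sum(math.exp(s - m) * v[x] for s, v in zip(scores, values))
         for x in range(dim)]
    return m, l, u

def merge(old, new):
    """Combine two sets of running statistics into one."""
    m_old, l_old, u_old = old
    m_blk, l_blk, u_blk = new
    m = max(m_old, m_blk)
    a = math.exp(m_old - m)   # rescale factor for the old statistics
    b = math.exp(m_blk - m)   # rescale factor for the new block
    l = a * l_old + b * l_blk
    u = [a * x + b * y for x, y in zip(u_old, u_blk)]
    return m, l, u

# Two halves of one row, merged, give the same (m, l, u) as the full row.
m, l, u = merge(block_stats([0.5, 2.0], [[1.0, 0.0], [0.0, 1.0]]),
                block_stats([-1.0, 0.3], [[1.0, 1.0], [2.0, 0.0]]))
o = [x / l for x in u]
```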
## Step 4: What FlashAttention is actually optimizing
The important thing is that *FlashAttention* does **not** approximate softmax. The final answer is mathematically the same as standard attention.
What changes is the execution order:
1. load a small block of $Q$, $K$, and $V$
2. compute a block of scores
3. update the running softmax statistics
4. immediately accumulate the output
5. move on to the next block
The large $L \times L$ score matrix never has to be written out to high-bandwidth memory, and the large $L \times L$ probability matrix never has to be materialized there either.
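Putting the pieces together, a toy pure-Python version of this schedule could look like the following (only an illustration of the blockwise order of computation; the real kernel also tiles the query dimension and keeps blocks in on-chip SRAM):

```python
import math

def flash_like_attention(Q, K, V, block_size):
    """Blockwise attention: per-row (m, l, u) instead of a full score matrix."""
    d_k = len(Q[0])
    d_v = len(V[0])
    scale = 1.0 / math.sqrt(d_k)
    out = []
    for q in Q:
        m, l = float("-inf"), 0.0
        u = [0.0] * d_v
        for start in range(0, len(K), block_size):   # 1. load one K/V block
            Kb = K[start:start + block_size]
            Vb = V[start:start + block_size]
            # 2. compute this block's scores
            scores = [scale * sum(q[x] * k[x] for x in range(d_k)) for k in Kb]
            # 3. update the running softmax statistics
            m_new = max(m, max(scores))
            a = math.exp(m - m_new)
            l = a * l + sum(math.exp(s - m_new) for s in scores)
            # 4. immediately accumulate the output
            u = [a * u[x] + sum(math.exp(s - m_new) * v[x]
                                for s, v in zip(scores, Vb))
                 for x in range(d_v)]
            m = m_new                                # 5. move to the next block
        out.append([ux / l for ux in u])
    return out

O = flash_like_attention([[1.0, 0.0], [0.0, 1.0]],
                         [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]],
                         [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
                         block_size=2)
```

The score matrix only ever exists one small block at a time, which is the whole point.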
## Memory cost of standard attention
Now let us compare the memory cost more concretely.
For standard attention, a forward pass usually materializes:
- $Q$, $K$, and $V$, each of shape $B \times H \times L \times d$
- the score matrix $S$, of shape $B \times H \times L \times L$
- the probability matrix $P$, also of shape $B \times H \times L \times L$
- the output $O$, of shape $B \times H \times L \times d$
If each element takes $b$ bytes, then the total memory just for these tensors is roughly:
$$
M_{\text{standard}} \approx b\left(4BHLd + 2BHL^2\right)
$$
The important part is the $2BHL^2$ term. Once $L$ becomes large, it dominates everything else.
For example, let:
- $B = 1$
- $H = 32$
- $L = 4096$
- $d = 128$
- $b = 2$ bytes
Then:
$$
\text{size}(Q) = \text{size}(K) = \text{size}(V) = \text{size}(O)
= 1 \times 32 \times 4096 \times 128 \times 2
\approx 32 \text{ MB}
$$
So $Q$, $K$, $V$, and $O$ together are about:
$$
4 \times 32 \text{ MB} = 128 \text{ MB}
$$
Now look at one of the $L \times L$ tensors:
$$
\text{size}(S)
= 1 \times 32 \times 4096^2 \times 2
\approx 1 \text{ GB}
$$
and $P$ costs about another 1 GB.
So standard attention is already at roughly:
$$
128 \text{ MB} + 1 \text{ GB} + 1 \text{ GB}
\approx 2.1 \text{ GB}
$$
for just one layer's major forward tensors.
If $L$ doubles from $4096$ to $8192$, the $L^2$ part becomes four times larger, so the $S$ and $P$ tensors jump from about 1 GB each to about 4 GB each. That is why long-context attention becomes painful so quickly.
## Memory cost of FlashAttention
With *FlashAttention*, we still need to store $Q$, $K$, $V$, and the final output $O$ in GPU memory. So the $BHLd$ part does not disappear.
What disappears from high-bandwidth memory is the need to materialize the full $S$ and $P$ tensors.
Instead, *FlashAttention* keeps:
- blocks of $Q$, $K$, and $V$ in fast on-chip SRAM / shared memory
- a running max $m$ for each row
- a running normalizer $\ell$ for each row
- the running output accumulator
From the perspective of GPU global memory, the main footprint becomes:
$$
M_{\text{flash}} \approx b\left(4BHLd\right) + \text{small row-wise statistics}
$$
The row-wise statistics are only $O(BHL)$, which is tiny compared with $O(BHL^2)$.
Using the same example as above:
- $Q$, $K$, $V$, and $O$ are still about 128 MB in total
- the extra running statistics are tiny compared with gigabytes
- the giant $S$ and $P$ tensors are no longer stored in HBM
So the memory bottleneck changes from:
$$
O(BHL^2)
$$
to:
$$
O(BHLd)
$$
with respect to the main attention activations.
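A quick back-of-the-envelope script makes the comparison concrete, using the example numbers from this post (a rough estimate of the major forward tensors only, ignoring the KV cache and workspace):

```python
def attention_memory_bytes(B, H, L, d, bytes_per_elem):
    """Rough forward-pass footprint: standard vs. flash-style attention."""
    qkvo = 4 * B * H * L * d * bytes_per_elem     # Q, K, V, O
    scores = 2 * B * H * L * L * bytes_per_elem   # S and P together
    standard = qkvo + scores
    flash = qkvo                                  # S and P never hit HBM
    return standard, flash

std, fla = attention_memory_bytes(1, 32, 4096, 128, 2)
print(std / 2**30)   # 2.125 -> about 2.1 GiB
print(fla / 2**30)   # 0.125 -> about 128 MiB
```

Doubling $L$ to $8192$ multiplies the `scores` term by four while `qkvo` only doubles, matching the scaling argument above.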
## Why this matters so much in practice
This also explains why the gain is most dramatic when the sequence length is large.
If $L$ is small, then $L^2$ is still manageable, and *FlashAttention* mainly helps by reducing memory traffic.
But when $L$ becomes large, standard attention spends huge memory on $S$ and $P$, while *FlashAttention* keeps memory growth much closer to the size of $Q$, $K$, $V$, and $O$.
There is one subtle point here: in autoregressive decoding with only **one new token**, the attention score shape is more like $1 \times L$ rather than $L \times L$, so the softmax matrix itself is not the main bottleneck. In that setting, the KV cache is often the bigger memory cost.
The biggest win of *FlashAttention* shows up in:
- training
- long-context prefilling
- full-sequence attention where many query positions are processed together
## Final takeaway
The standard implementation stores the full attention score matrix because it computes softmax in a materialize-first way: first build $S$, then build $P$, and only then multiply by $V$.
*FlashAttention* notices that this is not mathematically necessary. If we maintain the right running statistics for each row, we can compute exactly the same softmax and exactly the same output while processing the sequence block by block.
So the real optimization is not "changing softmax into something cheaper." The real optimization is changing **how softmax is executed**, so that the GPU avoids writing and reading enormous $L \times L$ intermediates. --></description></item></channel></rss>