Classic softmax attention, i.e. softmax(QK^T/sqrt(d_k))V, consists of two matrix multiplications.
This means O = QK^T first, and then softmax(O/sqrt(d_k))V.
The matrix O is quadratic with respect to the number of input tokens. Writing the O matrix to main memory is bound by the maximum bandwidth of your memory.
Then it has to be read out again to be multiplied against V.
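In NumPy, the two matmuls and the materialized score matrix look roughly like this (toy sizes; `naive_attention` is just an illustrative name, not anyone's actual kernel):

```python
# Naive softmax attention: two matmuls, with the full n x n score matrix
# materialized in between. Shapes are toy-sized for illustration.
import numpy as np

def naive_attention(Q, K, V):
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)                 # first matmul: the n x n score matrix O
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)         # row-wise softmax
    return P @ V                               # second matmul against V

n, d_k = 8, 4
rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, n, d_k))
out = naive_attention(Q, K, V)
print(out.shape)                               # (8, 4)
```

In a real implementation, `S` (the O matrix above) is what gets written to and re-read from main memory between the two matmuls.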
What flash attention does is change the algorithm. Flash attention is mathematically equivalent to softmax attention, though not bitwise identical, because the floating-point operations are reordered. The restructured algorithm allows you to fuse the previously independent kernels.
Instead of writing out the O matrix to main memory, its softmax is calculated against V immediately. The double memory roundtrip is now gone. This in itself does not change the fact that both softmax attention and flash attention are quadratic with respect to the input, but it sure as hell improves the speed of "prefill".
If you tile the Q, K, V matrices into n blocks each, you will still have to load O(n^2) blocks.
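A minimal sketch of the fused loop for a single query row, using the online-softmax trick; this is a NumPy stand-in for what the real kernel does tile-by-tile in SRAM (block size and function name are illustrative):

```python
# FlashAttention-style fused loop for one query row: stream K/V in tiles
# and keep a running softmax, so the n x n score matrix never exists.
import numpy as np

def flash_attention_row(q, K, V, block=4):
    d_k = q.shape[-1]
    m = -np.inf                  # running maximum of the scores
    l = 0.0                      # running softmax denominator
    acc = np.zeros(d_k)          # running weighted sum of V rows
    for j in range(0, K.shape[0], block):
        s = q @ K[j:j + block].T / np.sqrt(d_k)  # scores for this tile only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)                # rescale the old partial sums
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[j:j + block]
        m = m_new
    return acc / l
```

The loop still touches O(n^2) score entries in total, it just never round-trips them through main memory.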
But here is the thing. Matrix multiplication is an operation with a significant amount of shared data. This means the multipliers calculating the dot products are being fed from the same flip-flops, or the data is shifted around via a systolic array. You end up in a situation with an insignificant memory load but a massive amount of arithmetic.
In addition to that, you have all the tokens already, so the MLPs at the end of the layer can be processed as GEMM instead of GEMV.
This is why "prefill" is compute intensive instead of memory intensive.
During token generation, you need to perform attention for the next token, with all the tokens already in the KV cache. You load n entries from the KV cache, then do GEMV on the MLP and you have to do this over and over again in a sequential fashion. This means that memory bandwidth is the deciding factor for token generation.
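A back-of-envelope way to see why the decode-time GEMV is bandwidth-bound: every weight byte loaded supports only one multiply-add. The sizes below are illustrative (they happen to match a Llama-3-70B-ish MLP projection, but nothing depends on that):

```python
# Arithmetic intensity of one decode-step GEMV, y = W @ x, with W of
# shape (d_out, d_in). Each weight is read once and used in exactly one
# multiply-add, so FLOPs-per-byte is tiny compared with a GPU's
# compute-to-bandwidth ratio. Sizes are illustrative.
d_in, d_out, bytes_per_weight = 8192, 28672, 2   # fp16 weights

flops = 2 * d_in * d_out                   # one multiply + one add per weight
bytes_moved = d_in * d_out * bytes_per_weight
intensity = flops / bytes_moved            # FLOPs per byte loaded
print(intensity)                           # 1.0 for fp16: memory-bound
```

Modern accelerators can sustain hundreds of FLOPs per byte of bandwidth, so an intensity of ~1 leaves the compute units mostly idle, waiting on memory.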
Now here is a caveat: if your SRAM is limited relative to your TOPS, then it is possible that even flash attention is memory bound, but for a different reason. It's memory bound because the largest tile that fits in SRAM can be processed faster than it takes to load it from system memory or VRAM, and you are performing a quadratic number of tile-loading operations. This only becomes noticeable near the extreme top end of context lengths, between 32k and 128k tokens.
Let's summarize FlashAttention as follows: Att(i) computation without FA requires on the order of
O(seq_len*dk + seq_len^2)
memory accesses, whereas Att(i) computation with FA requires
O(seq_len^2*dk^2/SRAM_size)
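As a sanity check, the two expressions can be evaluated for a small configuration (seq_len=1k, dk=64, 32 KB of SRAM). Constant factors are dropped, and whether SRAM is counted in bytes or elements moves the ratio around, which is why exact per-head figures quoted in this thread will differ:

```python
# Order-of-magnitude memory-access counts from the two expressions above.
# Constant factors are dropped, so exact per-head figures will differ.
seq_len, d_k = 1024, 64
sram_elems = 32 * 1024 // 2               # 32 KB of SRAM at 2 bytes/element

accesses_without_fa = seq_len * d_k + seq_len ** 2         # 1_114_112
accesses_with_fa = seq_len ** 2 * d_k ** 2 // sram_elems   # 262_144
print(accesses_without_fa // accesses_with_fa)             # 4
```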
Q, K, V computation remains the same. And ATTN(0,n)*Wo also remains the same.
In a smaller model, with N=12, D=768, dk=64, seq_len=1k, SRAM=32KB, ..., the FA optimization roughly translates to 0.5M vs 4.5M per head (att(i)), so a ~10x improvement. But in the grand scheme of things, per attention layer it becomes ~45M vs ~91M, so only a ~2x net improvement.
> This is why "prefill" is compute intensive instead of memory intensive.
Yes, I think I agree, and I have corrected myself elsewhere in the thread. The point I actually wanted to convey in my initial comment, which somehow got lost throughout the discussion, is that prefill/training will benefit from FlashAttention/MLA but decode-time inference will not. I can agree that the formulation "only when memory access time dominates the compute in attention implementation" was wrong.
> During token generation ... memory bandwidth is the deciding factor for token generation.
A Llama3-70B MLP layer roughly takes 1 TFLOPs of compute and 0.6 GB of bandwidth for 1024 tokens. Assuming that 1023 entries are taken from the KV cache, the attention layer computation for a single token will take ~0.6 GFLOPs and ~0.2 GB of bandwidth. To load the rest of the values from the KV cache at FP16 precision will take us 1023*0.1MB, or ~1 GB.
So, ~1 TFLOPs of compute and ~1 GB of bandwidth per Transformer layer. On hardware such as the H100, this still looks like a compute-bound problem to me. OTOH, on a CPU with 15 TFLOPS of compute but <1 TB/s of memory bandwidth, it becomes a memory-bound problem. Or no?
For Llama 3 70B at batch size = 1, each MLP layer roughly takes 1x8192x28672x2 + 1x8192x28672x2 + 1x28672x8192x2 FLOPs ~= 1.41 GFLOPs, instead of ~1 TFLOPs.
Since the number differs by roughly 1024x, maybe you forgot that for the MLP you also only need to work on the last decoded token? With the KV cache, you don't need the hidden states of the previous tokens in attention anymore.
True. I wrote some software that does these calculations for me, besides the ones I already have on paper, and I confused two different graphs.
So, the final number would be ~0.6 GFLOPs (self-attention across heads) + ~0.15 GFLOPs (attention) + ~1 GFLOPs (ffwd), which in total, give or take, is ~2 GFLOPs per layer.
Bandwidth-wise, the ~1 GB number I previously gave was also wrong (Llama3-70B has 8 KV heads). With more precise calculations, that figure is ~0.6 GB per layer.
So, at batch_size=1, FP8 precision, and 1024 tokens, during the decode phase with a KV cache we need ~2 GFLOPs of compute and ~0.6 GB of bandwidth per layer. Still looks compute-bound to me.
The H100 has 3.3 TB/s of HBM bandwidth on paper, and ~1000 TFLOPS of bf16 compute on paper. That's 1:300. 0.6 GB vs ~2 GFLOPs is 1:3. Tell me, how is this compute-bound?
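To make the 1:300 vs 1:3 comparison concrete, here is a quick roofline-style check using the round numbers from this thread (a sketch, not a benchmark):

```python
# Roofline-style comparison: the hardware's FLOPs-per-byte vs the
# workload's FLOPs-per-byte, using the round numbers from this thread.
hw_flops_per_s = 1000e12      # ~1000 TFLOPS bf16 (H100, on paper)
hw_bytes_per_s = 3.3e12       # ~3.3 TB/s HBM bandwidth (on paper)
layer_flops = 2e9             # ~2 GFLOPs per layer (decode, batch 1)
layer_bytes = 0.6e9           # ~0.6 GB moved per layer

hw_intensity = hw_flops_per_s / hw_bytes_per_s   # ~303 FLOPs per byte
wl_intensity = layer_flops / layer_bytes         # ~3.3 FLOPs per byte
# wl_intensity << hw_intensity: the decode step is memory-bound
```

A workload only becomes compute-bound once its FLOPs-per-byte exceeds the hardware's ratio; here it is two orders of magnitude short.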
(Also, your number, even after accounting for GQA, is still off. You usually can't store the KV cache in FP8.)