edit time: 2026-05-16 18:20:34 做了FlashAttention部分好耶
edit time: 2026-05-16 19:04:38 做了PagedAttention部分好耶
总结:FlashAttention用于降低Attention计算阶段预存的显存消耗(内部计算范畴),
PagedAttention用于高效管理推理的KV Cache(外部推理范畴)。
FlashAttention:
- 问题:标准 Attention 需要存储 N×N 的中间矩阵,显存 O(N²)
- 解决:分块 + 在线 softmax,显存降到 O(N)
- 关键:外循环 Q 块,内循环 K、V 块,在线更新 m、l、O
PagedAttention:
- 问题:预分配连续显存导致大量浪费(内部碎片 + 外部碎片)
- 解决:分块 + 页表映射 + 动态分配
- 扩展:Prefix Cache + 引用计数,共享公共前缀
题目1:中间矩阵的显存问题
在单头情况下(分头时把 d_model 换成 d_model/num_heads )有
Q,K,V∈Rseq_len×d_model
S=softmax(Q×KT/d)×V∈Rseq_len×seq_len
- S矩阵有 105×105=1010 个元素,即序列长度的平方。
- 存储S需要 1010×4B=40GB 。
- 只有 24GB 时放不下。
总之,标准Attention的显存复杂度是序列长度的平方,序列长度变大时显存会爆炸。FlashAttention的目标就是不存储完整的 S 矩阵。
题目2:Softmax的在线计算
softmax是对行做的一个操作,把整行从分数转换成概率分布。
softmax(zi)=∑j∈1,2,…,nezj−mezi−m
其中 m=max{z1,z2,…,zn} 。
下述研究的是上式的“分母求和项”,即 ∑j∈1,2,…,nezj−m 。
但是问题在于,我们并不想立刻遍历大矩阵的整个行得到分母这个求和项,而是希望先对其中的一部分块做操作,得到一个部分的求和。只有在后面的块来了之后,才能逐渐更新这个部分和,直到整行都覆盖到之后得到全局和。
这怎么做呢?就是下述的在线softmax的思想。
维护在线最大值 m,和已看到的所有 ex−m 的和 l。
和 l 中的第一项是对之前计算的 l 乘以修正系数(因为在线最大值变了),第二项是加上在线到来的新的 ex−m 项。
mnew←max{mold,xnew}
lnew←lold⋅emold−mnew+exnew−mnew
初始时 [m,l]=[−∞,0] 。
接下来用一个简单例子来练习上述的更新过程:计算 x=[2,1,4,3] 的在线softmax。
mnew←max{−∞,2}=2
lnew←e2−2=1
看到第一个元素时的 [m,l] 为 [2,1] 。
mnew←max{2,1}=2
lnew←1×e2−2+e1−2=1+e−1≈1.3679
看到前两个元素时的 [m,l] 更新为 [2,1.3679] 。
mnew←max{2,4}=4
lnew←1.3679×e2−4+e4−4≈1.1851
看到前三个元素时的 [m,l] 更新为 [4,1.1851] 。
mnew←max{4,3}=4
lnew←.×e4−4+e3−4≈1.5529
看到全部四个元素时的 [m,l] 更新为 [4,1.5529] 。
忽略掉一些数值上的舍入误差(当然,也可以通过解析解完全相同来验证),会发现 l 的结果恰好等于我们对 x=[2,1,3,4] 希望求的分母求和项 l=e2−4+e1−4+e3−4+e4−4 。
题目3:FlashAttention的核心思想
现在我们有了标准Attention需要乘 N×N 的 S 矩阵(显存爆炸),
Softmax可以在不存全部值的情况下在线计算这两个知识。
那么,FlashAttention的解决方案就是,把 Q,K,V 都分块,每次只加载一小块到高速SRAM中计算,然后在线更新结果。
简化掉一些算法细节,大致的算法如下。
(为了新手理解方便,把 Q 放在了外层、K 放在了内层,
实际上因为 K/V 被多个 Q 块共用,应该把 K 放在外层、Q 放在内层)
注意它不存储完整的 S ,只存当前小块(Br×Bc ,其中 r 和 c 是row和column的意思)。

问题:对下述设置中的 Q1 尝试进行一次完整的迭代更新过程。
- 序列长度 N=4
- 分块大小 Br=2,Bc=2
- 特征维度 d=2,缩放因子 d=2
Q=⎣⎢⎢⎡10100110⎦⎥⎥⎤,K=⎣⎢⎢⎡10101100⎦⎥⎥⎤,V=⎣⎢⎢⎡10100110⎦⎥⎥⎤
分块(按行划分):
- Q1=Q[0:2]=[1001], Q2=Q[2:4]=[1010]
- K1=K[0:2]=[1011], K2=K[2:4]=[1000]
- V1=V[0:2]=[1001], V2=V[2:4]=[1010]
算法状态(对每个 Q 块单独维护):
- O:输出累加器(形状 Br×dv),初始为零矩阵
- m:当前看到的行最大值(形状 Br 的向量),初始为 −∞
- l:当前看到的分母和(形状 Br 的向量),初始为 0
外循环:遍历每个 Qi 块
解决:我们需要处理 Q1 (一次外层循环迭代),
那么分别要计算 K1 和 K2 的两次内层循环迭代。
首先计算 S11=Q1K1T/d 。
Q1K1⊤=[1001][1101]=[1⋅1+0⋅10⋅1+1⋅11⋅0+0⋅10⋅0+1⋅1]=[1101]
S11=1.4141[1101]≈[0.7070.70700.707]
接下来在线更新 m 和 l。
m=[−∞,−∞],l=[0,0]
row_max(S11)=[0.707,0.707]
mnew=max(m,row_max)=[0.707,0.707]
em−mnew=[0,0]
[0.707−0.7070.707−0.7070−0.7070.707−0.707]=[00−0.7070]
e(⋅)=[11e−0.7071]≈[110.4931]
row_sum=[1+0.493,1+1]=[1.493,2.000]
lnew=l⋅em−mnew+row_sum=[0,0]+[1.493,2.000]=[1.493,2.000]
再接下来更新输出 O ,处理 S11 。
初始 O=[0000]
eS11−mnew≈[110.4931],V1=[1001]
[1⋅1+0.493⋅01⋅1+1⋅01⋅0+0.493⋅11⋅0+1⋅1]=[110.4931]
更新 O(注意 O 也要乘修正系数 em−mnew=[0,0])
O←0+[110.4931]=[110.4931]
上述 K1 相关的已经计算完毕,接下来对 K2 执行相同步骤。
K2=[1000],K2⊤=[1000]
Q1K2⊤=[1001][1000]=[1000]
S12≈[0.707000]
当前 m=[0.707,0.707],l=[1.493,2.000]
row_max(S12)=[0.707,0]
mnew=max([0.707,0.707],[0.707,0])=[0.707,0.707]
修正系数 em−mnew=e[0,0]=[1,1]
S12−mnew=[0.707−0.7070−0.7070−0.7070−0.707]=[0−0.707−0.707−0.707]
eS12−mnew=[10.4930.4930.493]
row_sum=[1.493,0.986]
lnew=l⋅[1,1]+row_sum=[1.493,2.000]+[1.493,0.986]=[2.986,2.986]
eS12−mnew≈[10.4930.4930.493],V2=[1010]
[1⋅1+0.493⋅00.493⋅1+0.493⋅01⋅1+0.493⋅00.493⋅1+0.493⋅0]=[10.49310.493]
O←O⋅[1,1]=[110.4931]
O←[110.4931]+[10.49310.493]=[21.4931.4931.493]
接下来归一化 Q1 块的输出
O1=lO=[2/2.9861.493/2.9861.493/2.9861.493/2.986]≈[0.6700.5000.5000.500]
这就是 Q1 中两个 token 经过完整 softmax 注意力后的输出。
补充:离 原论文 还差在哪里?

事实上,
原版算法的 K/V 在外层,Q 在内层。这是为了共用。
原版算法最后才做归一化(diag(linew)−1 乘在前面),写到HBM时 Oi 已经是归一化后的值,下一轮循环可以直接使用(不用再除以 l),我们这里的简化版算法是在 Q 块全部循环结束后统一进行归一化的。
原版算法有堪称实现灵魂的块大小计算公式 Bc=⌈4dM⌉ 和 Br=min(⌈4dM⌉,d) ,保证 Qi,Kj,Vj,Oi 能同时放进SRAM,我们未对这一部分做阐述。
原版用 diag(ℓ)diag(ℓ) 来表达「逐行缩放」,数学上更严谨,也对应代码实现(广播乘法)。
简化用 O * e^(m - m_new) 隐式假设逐行乘,没有强调向量和矩阵的广播语义。
题目4:PagedAttention的显存管理优势
在生成文本时,每次只生成一个新的token,但做这个token的预测时,需要它之前所有token的 K 和 V 。标准做法是把之前算过的 K 和 V 存起来,避免重复计算,这就是KV Cache。
但是,由于每个请求所需要的tokens数未知、有的多有的少,
传统方法是直接给它们分配一块连续的显存,长度设为最大可能值。这需要连续空间、会造成显存浪费、甚至在释放时出现外部碎片,管理不够灵活。
为了解决这一问题,PagedAttention把分页管理的思想引入了KV Cache的管理。
直接看以下题目。
传统方式:每个请求预先分配最大长度 8 tokens 的连续显存
PagedAttention:block size = 4 tokens,按需动态分配
显卡总 KV Cache:32 tokens(即 8 个 block)
当前有 3 个请求:
请求 A:实际使用 3 tokens
请求 B:实际使用 5 tokens
请求 C:实际使用 2 tokens
1.传统方法分配量和浪费量是多少,浪费率是多少?
每个请求分配8 tokens,请求A浪费5 tokens,请求B浪费3 tokens,请求C浪费6 tokens。
总计分配量是24 tokens,总浪费量是14 tokens,浪费率是58.33%。
2.PagedAttention(Block Size = 4)时,每个请求按需分配block。分配量和浪费量是多少,浪费率是多少?
请求A需要一个block,浪费1 tokens。
请求B需要两个block,浪费3 tokens。
请求C需要一个block,浪费2 tokens。
总分配量为四个block对应16 tokens,总浪费量是6 tokens,浪费率是37.5%。
这显现了PagedAttention可以节省分配量,同时降低浪费。
3.请求B从5 tokens增长到7 tokens,传统方式和PagedAttention是否需要重新分配?如果请求B是从5 tokens增长到9 tokens呢?
对于增长到7 tokens,传统方法和PagedAttention均不需要重新分配。因为预先分配的均已经足够用了。
而对于增长到9 tokens,传统方法需要重新分配更大(16 tokens)的连续空间,并且将原来的5 tokens也拷贝到新空间、再写入新增加的4 tokens;
而PagedAttention只需要动态增加1个block。
这显现了PagedAttention不需要连续内存,动态分配、更加灵活。
题目5:Block Table的工作原理
Block Table 实现了逻辑连续、物理不连续的映射,就像操作系统的虚拟内存。
直接看题目。
设定Block Size = 4,请求的 Block Table=[3,1,5] 。
(表示逻辑块0→物理块3,逻辑块1→物理块1,逻辑块2→物理块5)
- 画出这个请求的物理布局(哪个物理块存哪几个 token)
- 要访问 token 索引 9(第 10 个 token),应该去哪个物理块的哪个偏移位置?
- 如果这个请求共有 10 个 tokens(索引 0-9),最后一个 token(索引 9)在哪个物理
块?这个块是否存满了?
1.按顺序排布各个物理块:
物理块0(无请求)
物理块1(token 4,5,6,7)
物理块2(无请求)
物理块3(token 0,1,2,3)
物理块4(无请求)
物理块5(token 8,9,10,11)
2.访问token索引9时应该去逻辑块 9//4=2 ,查表对应物理块5,块内偏移为 9%4=1 。
3.同2可知,最后一个token在物理块5,这个块只存了token 8,9,没有存满。
题目6:Prefix Cache(前缀共享)的引用计数实现
我们发现很多请求共享相同的前缀。例如:
- 系统 prompt:"你是一个 AI 助手"(5 tokens)
- 请求 A:"你是一个 AI 助手,今天天气怎么样"(11 tokens)
- 请求 B:"你是一个 AI 助手,推荐一部电影"(11 tokens)
两个请求的前 5 个 tokens 完全相同。
传统方式是各自存一份 KV Cache,这会浪费 5 tokens × 2 = 10 tokens。
而Prefix Cache则是只存一份共享的prefix,每个请求的Block Table指向共享blocks,用引用计数管理(多个请求共享时,ref_count > 1;没有请求使用时才释放)。
根据上述思想,参考自 nanovllm/engine/block_manager.py 的简化代码如下:
(思想很简单,就是引用计数,共享的只存一份,每次要分配时先看看是不是已经存过了)
class Block:
def __init__(self, block_id):
self.block_id = block_id
self.ref_count = 0 # 引用计数
self.hash = -1 # 内容哈希,用于查找共享
self.token_ids = [] # 实际存储的 tokens
# 分配时,先通过哈希查找是否已有相同内容的 block
hash = compute_hash(token_ids, prefix_hash)
if hash in hash_to_block_id:
block = blocks[block_id]
block.ref_count += 1 # 共享!只是增加引用
else:
# 分配新 block
block = allocate_new_block()
设定:
- block_size = 2
- 基础 prefix:tokens
[A, B, C, D](需要 2 个 blocks)
- 请求 1:
[A, B, C, D, E, F](3 个 blocks)
- 请求 2:
[A, B, C, D, G, H](3 个 blocks)
初始状态:空闲 blocks = [0,1,2,3,4,5,...] 无限多
问题:
- 分配基础 prefix(假设还没有人用过),画出两个 block 的 ref_count
- 分配请求 1,哪些 blocks 是共享的?ref_count 变成多少?
- 分配请求 2,再次共享 prefix,更新 ref_count
- 请求 1 结束,deallocate 时,哪些 blocks 的 ref_count 减少?哪些被真正释放?
- 最终请求 2 的 Block Table 是什么?
1. 分配基础 prefix
- Block0:存
[A, B],hash=h_AB,ref_count=1
- Block1:存
[C, D],hash=h_CD,ref_count=1
2. 分配请求 1(需要 [A,B,C,D,E,F])
请求 1 的 blocks:
- Block0:
[A,B],哈希匹配 → 共享,ref_count=2
- Block1:
[C,D],哈希匹配 → 共享,ref_count=2
- Block2:
[E,F],新 block,ref_count=1
hash_to_block_id 新增:h_EF → Block2
3. 分配请求 2(需要 [A,B,C,D,G,H])
- Block0:
[A,B],哈希匹配 → 共享,ref_count=3
- Block1:
[C,D],哈希匹配 → 共享,ref_count=3
- Block3:
[G,H],新 block,ref_count=1
hash_to_block_id 新增:h_GH → Block3
4. 请求 1 结束,deallocate
遍历请求 1 的 Block Table [0,1,2](逆序):
- Block2:ref_count=1 → 减到 0,释放(从 hash_to_block_id 删除,加入 free list)
- Block1:ref_count=3 → 减到 2(仍被基础 prefix 和请求 2 共享)
- Block0:ref_count=3 → 减到 2(仍被共享)
5. 请求 2 的最终 Block Table
请求 2 的 Block Table = [0, 1, 3]
- Block0(
[A,B]),Block1([C,D])仍在,ref_count=2
- Block3(
[G,H]),ref_count=1
核心优势:通过引用计数,我们节省了 2 blocks × 2 = 4 tokens 的显存。实际场景中,系统 prompt 可能很长(如 1000 tokens),共享效果非常显著。