从原始Softmax到在线Softmax
撰写了原始的softmax、数值稳定的softmax、在线softmax的公式及其简单代码呈现。
edit time: 2026-04-09 16:37:03
softmax形式推导
我们讨论的起点是原始的softmax形式:
为保证数值稳定,令
则 softmax 可等价改写为数值稳定的softmax形式:
其实就是给分子分母的 先减去 ,然后再使用exp函数,相当于给 计算式的分子分母除以 ,计算得到的 值不变。
这样的好处是,如果某个 值非常大,本来exp一下就溢出了(脑补一下指数函数的图像,比较大时能够狠狠地增大数值),可是我们先减了下来再exp,原理上计算得到的 值不会发生变化。
对其取对数,得到 log-softmax:
其中
称为 log-sum-exp,因此也可写为
接下来,我们考虑如何修改为在线softmax,以支持FlashAttention的分块计算?
核心是:在有了新块中得分的最大值 之后,我们需要把已计算的最大值 更新为
更新后的 可能会影响新的块和已计算的部分的softmax值。因此,需要给已计算过的部分乘以修正系数,并且应该用更新后的 计算新的块。
总之,关键问题是,要用正确的最大值 计算softmax。
详细来看,是以下这样的:
对第 个 query,设第 轮已经处理完前 个列块。定义当前累计最大值为
线性归一化项(其实就是分母那一堆e的求和的高级叫法)为
未归一化输出累计量(其实就是没归一化的分子乘积的高级叫法)为
因此,当前输出可写为
若第 个新块对应的 score 集合为 ,记该块内最大值为
则合并后的最大值更新为
由于归一化基准从 变成了 ,旧块对应的指数项都要乘上修正系数
因此,未归一化输出的递推公式为
同理,线性归一化项满足
最终输出为
当然,若记块内的未归一化输出与归一化项分别为
则也可简写为
总之,再来一遍结论。在线 softmax 的关键就在于:
每次加入新块时,都要先用新的全局最大值
重新对旧块和新块做统一归一化,再更新 与 ,最后得到输出 。
我怀疑FlashAttention网上/GPT的常见讲解有一处推错的地方
省流:结论 没问题,
但是我怀疑有些讲解为了推这个结论用到的 这个假设出发点有问题。
只是个人看法,我也不确定自己这么想对不对,
欢迎大家在评论区或者Github的 Issues 页讨论~
原始推导的问题
在网上,有些讲解对 这个近似结论的推导是出发自 的:
记SRAM占用量为:
若 ,则占用 ,SRAM大小 需满足:
问题:上述 条件及其推出的结果 与后续另一条一定满足的条件 是矛盾的。
原因如下:
当 时,求解取等方程
B=-2d+ \sqrt{4d^2+M}$$ (取负舍去) 即有 $$B = \frac{M}{\sqrt{4d^2+M}+2d}
此时若有 ,才有
这与 或 这一长上下文肯定满足的条件恰好是相反的。
总之,按照原文来说,上述 假设是与原文条件矛盾的推导。
不使用 ,推导得到
应该保留原文的长方形分块 和 ,即按照下述流程推导:
根据每次迭代在SRAM中需要存储的内容,有
故SRAM中需要同时容纳的主要对象对应三类约束:
因此我们正确地得到 满足的条件:
代入约束 ,得到
另一方面,由第二条约束可得
因此综合得到 满足的条件:
验证上述两条条件与 的假设相容:
在FlashAttention原文关注的情况下,通常满足以下假设
并且长上下文时常满足
此时
所以
这时并不存在之前正方形分块下的矛盾,因为原文采用的是长方形分块:
行块高度只有 ;列块宽度可以达到 。
于是中间块大小为
恰好仍能放入SRAM。
接下来就可以继续后续的推导,
把上述的条件 代回我们在 ii 中推导的
即有
等等。
softmax实际代码实现
接下来联系到实际的代码实现,特别是softmax和log-sum-exp的PyTorch实现及其与交叉熵损失和FlashAttention的关系。
标准的数值稳定softmax应该为:
总之就是首先需要计算一个最大值,然后计算减去最大值后的softmax。
import torch
def softmax_stable(logits, dim=-1):
"""
logits: [..., n]
return: [..., n] 概率分布
"""
# 减去最大值防止溢出
max_val = torch.max(logits, dim=dim, keepdim=True)[0]
exp_logits = torch.exp(logits - max_val)
sum_exp = torch.sum(exp_logits, dim=dim, keepdim=True)
return exp_logits / sum_exp
# PyTorch 内置版本(推荐,数值更稳定)
# attn_weights = torch.softmax(scores, dim=-1)
而交叉熵损失则为:
def cross_entropy_loss(logits, labels):
"""
logits: [batch_size, num_classes] 原始输出(未经过 softmax)
labels: [batch_size] 真实类别索引 (0 ~ num_classes-1)
"""
# 先算 log_softmax(数值稳定)
max_val = torch.max(logits, dim=1, keepdim=True)[0]
exp_logits = torch.exp(logits - max_val)
sum_exp = torch.sum(exp_logits, dim=1, keepdim=True)
log_probs = logits - max_val - torch.log(sum_exp) # log(softmax)
# 取出正确类别的 log 概率
batch_size = logits.shape[0]
correct_log_probs = log_probs[range(batch_size), labels]
return -correct_log_probs.mean()
# PyTorch 内置版本(推荐,数值更稳定)
# loss_fn = torch.nn.CrossEntropyLoss() # 内部自动做 log_softmax + NLLLoss
在线Softmax的核心实现则为:
其中 是行row的块大小(即 的块大小,因为是 ),
是列column的块大小。
def flash_attention_online_softmax(Q, K, V, B_r, B_c):
"""
B_r: Q 的块大小,B_c: K/V 的块大小
Q, K, V: [n, d] 假设 batch=1, heads=1 简化
"""
n, d = Q.shape
O = torch.zeros_like(Q) # 输出累加器
l = torch.zeros(n, 1) # 指数和统计量
m = torch.full((n, 1), -float('inf')) # 最大值统计量
# 外层循环:遍历 KV 块
for j in range(0, n, B_c):
Kj = K[j:j+B_c] # [B_c, d]
Vj = V[j:j+B_c] # [B_c, d]
# 内层循环:遍历 Q 块
for i in range(0, n, B_r):
Qi = Q[i:i+B_r] # [B_r, d]
Oi = O[i:i+B_r] # [B_r, d]
mi = m[i:i+B_r] # [B_r, 1]
li = l[i:i+B_r] # [B_r, 1]
# 计算当前块的得分矩阵 [B_r, B_c]
Sij = torch.matmul(Qi, Kj.T) / (d ** 0.5)
# 当前块的最大值(每行)
mij_new = torch.max(Sij, dim=1, keepdim=True)[0] # [B_r, 1]
# 当前块的指数和(以块内最大值归一化)
Pij = torch.exp(Sij - mij_new) # [B_r, B_c]
lij_new = torch.sum(Pij, dim=1, keepdim=True) # [B_r, 1]
# 当前块的加权输出(块内)
Oij_new = torch.matmul(Pij, Vj) # [B_r, d]
# ========== 在线 softmax 合并 ==========
# 更新全局最大值
m_new = torch.maximum(mi, mij_new) # [B_r, 1]
# 缩放旧统计量
scale_old = torch.exp(mi - m_new) # [B_r, 1]
scale_new = torch.exp(mij_new - m_new) # [B_r, 1]
# 更新指数和
l_new = scale_old * li + scale_new * lij_new # [B_r, 1]
# 更新输出(关键:旧的输出需要按比例缩放)
O_new = (scale_old * Oi * li + scale_new * Oij_new) / l_new
# 写回
O[i:i+B_r] = O_new
m[i:i+B_r] = m_new
l[i:i+B_r] = l_new
return O
手推小例子
- 标准 Softmax 数值稳定计算
给定 logits = [2.0, 1.0, 0.1],写出数值稳定的 softmax 计算过程。
解:
max = 2.0
logits - max = [0.0, -1.0, -1.9]
exp = [e^0, e^{-1}, e^{-1.9}] = [1.0000, 0.3679, 0.1496]
sum_exp = 1.5175
softmax = [1/1.5175, 0.3679/1.5175, 0.1496/1.5175] = [0.6590, 0.2424, 0.0986]
验证和 = 1.0000
答:p = [0.6590, 0.2424, 0.0986]
- 交叉熵损失计算
设模型输出 logits = [2.0, 1.0, 0.1],真实标签为类别 0 ,计算交叉熵损失。
解:
softmax 结果 p = [0.6590, 0.2424, 0.0986]
正确类别的概率 p_c = 0.6590
L = -log(p_c) = -log(0.6590) = 0.4170
答:L = 0.4170
- 标准 Attention 中的 Softmax
已知 Q, K, V 如下(n=3, d=2),计算 Attention 输出。
Q = [[1.0, 0.5],
[0.5, 1.0],
[0.0, 1.0]]
K = [[0.8, 0.6],
[0.3, 0.9],
[1.0, 0.2]]
V = [[0.5, 1.0],
[0.8, 0.3],
[0.2, 0.7]]
sqrt(d) = sqrt(2) = 1.4142
第一步:计算 S = QK^T / sqrt(d)
QK^T =
行0·K^T: [1.0*0.8+0.5*0.6=1.1, 1.0*0.3+0.5*0.9=0.75, 1.0*1.0+0.5*0.2=1.1]
行1·K^T: [0.5*0.8+1.0*0.6=1.0, 0.5*0.3+1.0*0.9=1.05, 0.5*1.0+1.0*0.2=0.7]
行2·K^T: [0.0*0.8+1.0*0.6=0.6, 0.0*0.3+1.0*0.9=0.9, 0.0*1.0+1.0*0.2=0.2]
S = 除以 1.4142:
行0: [1.1/1.4142=0.7778, 0.75/1.4142=0.5303, 1.1/1.4142=0.7778]
行1: [1.0/1.4142=0.7071, 1.05/1.4142=0.7425, 0.7/1.4142=0.4950]
行2: [0.6/1.4142=0.4243, 0.9/1.4142=0.6364, 0.2/1.4142=0.1414]
第二步:每行 softmax
行0: max=0.7778, 减后=[0, -0.2475, 0], exp=[1, 0.7808, 1], sum=2.7808, p=[0.3596, 0.2808, 0.3596]
行1: max=0.7425, 减后=[-0.0354, 0, -0.2475], exp=[0.9652, 1, 0.7808], sum=2.7460, p=[0.3515, 0.3642, 0.2843]
行2: max=0.6364, 减后=[-0.2121, 0, -0.4950], exp=[0.8088, 1, 0.6098], sum=2.4186, p=[0.3344, 0.4135, 0.2521]
第三步:O = P × V
O[0] = 0.3596×[0.5,1.0] + 0.2808×[0.8,0.3] + 0.3596×[0.2,0.7]
= [0.1798,0.3596] + [0.2246,0.0842] + [0.0719,0.2517] = [0.4763, 0.6955]
O[1] = 0.3515×[0.5,1.0] + 0.3642×[0.8,0.3] + 0.2843×[0.2,0.7]
= [0.1758,0.3515] + [0.2914,0.1093] + [0.0569,0.1990] = [0.5241, 0.6598]
O[2] = 0.3344×[0.5,1.0] + 0.4135×[0.8,0.3] + 0.2521×[0.2,0.7]
= [0.1672,0.3344] + [0.3308,0.1241] + [0.0504,0.1765] = [0.5484, 0.6350]
答:O = [[0.4763, 0.6955], [0.5241, 0.6598], [0.5484, 0.6350]]
- 在线 Softmax 合并(一个 query,两个块)
完整 scores = [2.0, 1.5, 3.0, 0.5]
V = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.8, 0.2]]
(1) 标准方法(全局 softmax)
max = 3.0
exp(减max后) = [e^{-1}, e^{-1.5}, e^{0}, e^{-2.5}] = [0.3679, 0.2231, 1.0000, 0.0821]
sum_exp = 1.6731
probs = [0.2199, 0.1334, 0.5977, 0.0491]
output = 0.2199×[1,0] + 0.1334×[0.5,0.5] + 0.5977×[0,1] + 0.0491×[0.8,0.2]
= [0.2199,0] + [0.0667,0.0667] + [0,0.5977] + [0.0393,0.0098]
= [0.3259, 0.6742]
(2) 在线方法
初始化:m = -∞, l = 0, o = [0,0]
块1:s=[2.0,1.5], v=[[1,0],[0.5,0.5]]
m1=2.0, exp(减2)=[1,0.6065], l1=1.6065
o1 = (1/1.6065)×[1,0] + (0.6065/1.6065)×[0.5,0.5] = [0.6225,0] + [0.1887,0.1887] = [0.8112,0.1887]
合并:m_new = max(-∞,2)=2, scale_old=0, scale_new=1
l = 0 + 1×1.6065 = 1.6065
o = (0 + 1×[0.8112,0.1887]×1.6065) / 1.6065 = [0.8112,0.1887]
块2:s=[3.0,0.5], v=[[0,1],[0.8,0.2]]
m2=3.0, exp(减3)=[1,0.0821], l2=1.0821
o2 = (1/1.0821)×[0,1] + (0.0821/1.0821)×[0.8,0.2] = [0,0.9241] + [0.0607,0.0152] = [0.0607,0.9393]
合并:m_prev=2, m_new=max(2,3)=3
scale_old = e^{2-3}=0.3679, scale_new = e^{3-3}=1
l = 0.3679×1.6065 + 1×1.0821 = 0.5910 + 1.0821 = 1.6731
o = (0.3679×[0.8112,0.1887]×1.6065 + 1×[0.0607,0.9393]×1.0821) / 1.6731
= (0.3679×[1.3030,0.3031] + [0.0657,1.0165]) / 1.6731
= ([0.4793,0.1115] + [0.0657,1.0165]) / 1.6731
= [0.5450, 1.1280] / 1.6731 = [0.3257, 0.6742]
答:在线方法得到 [0.3257, 0.6742],与标准方法的 [0.3259, 0.6742] 一致(误差 < 1e-4)。