说实话,在AI基础设施领域,"灵活"和"高性能"向来是一对冤家。你要么忍受Python的慢速去换取开发效率,要么啃CUDA代码去追求极致性能。

但现在,PyTorch用一组数字狠狠打了这个"二选一"的脸——在Blackwell GPU上,FlexAttention的新后端比原有Triton实现快了1.2到3.2倍。

一千个仓库的"甜蜜负担"

AI配图

FlexAttention从诞生那天起,就背负着一个使命:让研究者用几行Python代码就能实现自定义attention变体,不用碰一行CUDA。

结果呢?超过1000个GitHub仓库采用了它,几十篇论文引用它。 ALiBi、滑动窗口、文档掩码、soft-capping……这些花式attention统统可以用同一个接口搞定。

但问题来了。

用户们一边用得爽,一边开始抱怨:性能天花板太难突破了。

有多难?看看数据就知道了——FlexAttention在Hopper GPU上的吞吐量,只有FlashAttention-3的60%左右。而且这个差距还在扩大,因为两边都在优化,但FlashAttention优化得更快。

一个典型的模式是:研究者用FlexAttention做实验,发现某个attention变体有效,然后——卡住了。性能上不去,只能找专家用底层代码重写。

相当于把负担从研究者转移给了ML工程师。

Blackwell:Tensor Core越强,问题越大

到了新一代Blackwell GPU,情况更夸张。

看这张图,FlexAttention的Triton实现和cuDNN attention之间的差距,从"有点明显"变成了"一道鸿沟":

为什么?

Blackwell的Tensor Core变得更大更快了,但负责指数运算的特殊功能单元(SFU)并没有同步进化。结果就是:softmax的exp()操作现在和矩阵乘法一样贵了。

更关键的是,Blackwell引入了Tensor Memory(TMEM),一种程序员可管理的、靠近Tensor Core的暂存区。数据搬运和矩阵乘法现在都可以完全异步——一个warp发起操作后可以立即干别的事。

这意味着什么?意味着高性能attention需要深度流水线、warp专业化的内核。这种级别的精细编排,Triton编译器根本表达不出来。

FlashAttention-4的解决方案是"ping-pong"策略:在两个tile之间来回切换,用一个tile的矩阵乘法掩盖另一个tile的指数运算延迟。

老实讲,这种级别的底层编排,通用编译器很难自动发现。就像素材里引用的那句话:"当Triton编译器被手写代码打败时,用户几乎无能为力——所有细节都被藏起来了。"

一次跨团队的"联姻"

PyTorch团队做了一个聪明的决定:与其自己从头造轮子,不如直接和FlashAttention团队合作。

Tri Dao等人当时正在开发FlashAttention-4(FA4),一个能充分利用Blackwell硬件的新实现。双方决定联手——不是并行开发两套代码,而是直接扩展FA4,让它成为FlexAttention的后端。

这里有个关键角色:CuTeDSL

这是NVIDIA CUTLASS团队最近发布的Python DSL,让你用Python写高性能CUDA内核。以前必须用CUTLASS C++写的东西,现在可以在Python里搞定了。

PyTorch的Inductor编译器可以把用户的score_modmask_mod函数转换成CuTeDSL代码,然后JIT编译成FA4可以用的形式。

换句话说:用户写Python,编译器生成CuTeDSL,FA4负责执行。

"Flex化"的FlashAttention

具体来说,FlexAttention需要FA4提供两个扩展点:

一是score modification——在forward和backward中注入用户自定义的分数修改逻辑。

实现上,内核把S tile从TMEM加载到寄存器,在计算max/sum和生成P tile的同时应用用户的修改。这样既保持了流水线结构,又不增加额外的阶段。

二是block-sparse iteration——前向和反向都只处理掩码中存在的tile。

这里有个有趣的细节:Blackwell上最小稀疏块大小是256×128,比Triton路径上的128×128更大。因为每个CTA需要处理两个M tile才能让流水线满载。

有意思的是,这些扩展全部用CuTeDSL实现,可以直接内联到FA4的异步流水线里。

数字会说话

来看实打实的结果。

在GB200上,Flash后端相比Triton实现:

  • 前向传播:1.6-3.2倍加速
  • 反向传播:1.85-2.3倍加速

某些情况下,反向传播甚至比cuDNN还快。

对于FlexAttention特有的模式(ALiBi、文档掩码、滑动窗口等):

  • ALiBi:前向1.2-2.1倍,反向1.9-2.9倍
  • 文档掩码:前向最高2.7倍,反向最高3倍
  • 滑动窗口:前向1.4-2.1倍,反向1.8-2.2倍

在Hopper H200上,提升同样明显——ALiBi前向加速1.30-1.54倍,文档掩码前向加速1.41-1.89倍。

他们还用Llama 3 70B在64块H100上做了真实训练验证。结果?Flash后端和Triton后端的训练曲线完美重合,1000步后都收敛到约3.7的最终loss。

当然,还没到完美的时候

我个人觉得,这篇博客最值得称赞的是——它没有回避局限性。

比如,FA4路径目前对块大小的灵活性有限——Hopper上锁定128×128,Blackwell上锁定256×128。

比如,动态标量会触发重编译。如果你的soft_cap值每次调用都变,每个不同的值都会触发一次编译。

比如,captured buffer的反向传播还不支持。如果你有可学习的偏置张量,还得用Triton后端。

还有性能瓶颈:KV维度上的加载可能阻塞流水线;需要pre-softmax scores的反向传播几乎一定会溢出寄存器……

但说实话,这些问题都是"幸福的烦恼"。核心问题已经解决了——你不再需要在灵活性和性能之间做选择了。

从研究到生产,一条路走到底

AI配图

回顾整个故事,你会发现这是一次漂亮的"技术缝合":

PyTorch有灵活的API和编译器基础设施,FlashAttention团队有极致优化的内核,NVIDIA有CuTeDSL这个新工具。三方合力,把研究者的Python代码直接变成了Blackwell上的高性能内核。

FlexAttention的初衷是让研究者快速原型化。现在,原型就是生产代码。

正如素材里那句话:"No more choosing between flexibility and performance."

AI配图

这才是真正的democratization——不是把技术做得简单到傻瓜都能用,而是让专业的人不用在"好用"和"快用"之间反复横跳。

至于MPS支持?评论区已经有人在问了。PyTorch团队,球踢给你们了。


【glm-5锐评】:PyTorch这次算是把"既要又要"的技术债还清了——用编译器的复杂性换用户的简单性,这才是基础设施该干的事。

参考链接:
https://x.com/PyTorch/status/2029617988899381376