數(shù)字金融
網(wǎng)絡營銷推廣
電商服務
來源:量子位
時隔一年,F(xiàn)lashAttention-3已經(jīng)全方位升級。訓練速度提升1.5-2倍,F(xiàn)P16下計算吞吐量高達740TFLOPs/s,達理論最大吞吐量75%,更充分利用計算資源,此前只能做到35%。FP8下速度接近1.2PFLOPs/s!同時誤差也進一步減小,F(xiàn)P8下的誤差比標準Attention減少2.6倍。
大模型訓練推理神作,又更新了!
主流大模型都在用的FlashAttention,剛剛升級第三代。
時隔一年,F(xiàn)lashAttention-3已經(jīng)全方位升級。
訓練速度提升1.5-2倍,F(xiàn)P16下計算吞吐量高達740TFLOPs/s,達理論最大吞吐量75%,更充分利用計算資源,此前只能做到35%。
FP8下速度接近1.2PFLOPs/s!
同時誤差也進一步減小,F(xiàn)P8下的誤差比標準Attention減少2.6倍。
而且這一次,不再是一作Tri Dao單打獨斗,F(xiàn)lashAttention-3直接和英偉達、Meta、谷歌等合作,針對最強芯片H100專門做優(yōu)化。
英偉達CUTLASS團隊和cuDNN團隊,都直接為該研究提供支持。
同時和前作一樣,F(xiàn)lashAttention-3也將開源,PyTorch和Hugging Face中都集成。
作者之一Vijay Thakkar激動表示:
曾經(jīng)在FA2發(fā)布時,我就說過這句話。今天,我想再說一次:
看到CUTLASS和CuTe被用來開讓Tensor Core大顯身手的新算法,真的泰褲辣。
前Stable Diffusion老板Emad也非常關注這一進展,他推測使用FlashAttention-3,能將4090的FP8計算吞吐量推升到700+TFLOPs。
充分利用Hopper架構(gòu)特點
自初代發(fā)布以來,F(xiàn)lashAttention已經(jīng)使大模型速度提高了4-8倍,但還有一個遺憾:尚未充分利用現(xiàn)代 GPU。
針對英偉達H100倍后的Hopper架構(gòu)新特性,三代進行了專門優(yōu)化。
整個系列的核心思路,是IO感知優(yōu)化和分塊處理。
作者認為,傳統(tǒng)的注意力機制效率低的原因,在處理長序列時,會出現(xiàn)內(nèi)存訪問操作頻繁,以及算法復雜度指數(shù)級暴增這兩大問題。
FlashAttention通過IO感知優(yōu)化將數(shù)據(jù)從較大但緩慢的高帶寬內(nèi)存(HBM)加載到較小但更快的片上內(nèi)存(SRAM),在SRAM中執(zhí)行計算,減少了內(nèi)存讀寫操作的次數(shù)。
分塊處理則是將輸入序列分成若干小塊,每次只處理一個小塊的數(shù)據(jù)。這種方法使得每次處理的數(shù)據(jù)量減少,從而降低了內(nèi)存使用和計算復雜度。
這樣一來,兩個關鍵問題就得到了解決,這兩大核心思想也在本次的FlashAttention-3中得到了繼承。
但是,第一代的FlashAttention也遺留下了并行性不夠強、工作分區(qū)劃分不合理,以及非矩陣乘法較多(GPU計算單元處理矩陣乘法比非矩陣速度更快)的問題。
針對這一問題,第二代FlashAttention通過重寫softmax,減少了重新縮放操作、邊界檢查和因果屏蔽操作的次數(shù),使得大部分計算集中在矩陣乘法上。
另外,F(xiàn)lashAttention-2引入了序列長度維度上的并行化,并針對工作在線程塊之間的分配進行了優(yōu)化,GPU利用效率更高了。
可以說前兩代當中,作者一直堅持著充分利用硬件特點這一思路,但站在今天的視角來看,對硬件的挖掘仍然不夠充分。
到了這次的FlashAttention-3,由于是直接和英偉達官方合作,對英偉達Hopper架構(gòu)特點的理解更加透徹,軟硬件之間的協(xié)同進一步增強了。
FlashAttention-3的技術(shù)報告顯示,為了充分匹配Hopper架構(gòu),團隊主要做了三方面的技術(shù)升級。
首先,Hopper架構(gòu)的一個重要特點是Tensor Core的異步性,F(xiàn)lashAttention-3針對性地提出了一種異步方式。
具體來說,F(xiàn)lashAttention-3引入了一種“生產(chǎn)者(Producer)-消費者(Consumer)”的編程模型,將注意力的計算劃分為兩個角色。
“生產(chǎn)者”負責將數(shù)據(jù)從HBM異步加載到片上共享內(nèi)存(SMEM)。這個過程主要利用了Hopper GPU的張量內(nèi)存加速器(TMA),可以在不阻塞CUDA核心的情況下進行數(shù)據(jù)傳輸。
消費者直接從共享內(nèi)存讀取數(shù)據(jù),并使用Tensor Core執(zhí)行矩陣乘法等計算密集型任務。由于共享內(nèi)存的訪問延遲遠低于全局內(nèi)存,消費者可以快速獲取所需數(shù)據(jù),提升計算效率。
為了實現(xiàn)角色的劃分,作者引入了warp專門化技術(shù),用不同的warp分別匹配生產(chǎn)者和消費者,讓兩者可以并行執(zhí)行。
這其中利用了Hopper架構(gòu)的動態(tài)warp寄存器分配特性,通過setmaxnreg指令優(yōu)化了寄存器資源的利用。
為了進一步提高GPU的利用率,作者又提出了一種“乒乓調(diào)度”策略,讓一個warp組執(zhí)行矩陣乘法時,另一個warp組執(zhí)行softmax,從而實現(xiàn)計算的重疊。
具體講,F(xiàn)lashAttention-3使用CUDA的同步原語控制不同warp組之間的執(zhí)行順序,讓不同warp組分別執(zhí)行兩種運算,然后像乒乓球一樣交替運行。
第二大技術(shù)特點,是warp組內(nèi)部GEMMs和softmax的重疊,核心奧義是重新安排計算的執(zhí)行順序以提高GPU利用率。
與乒乓調(diào)度不同,這里的計算重排處理的是warp組內(nèi)部的重疊,而乒乓調(diào)度更關注組間協(xié)調(diào)。
實現(xiàn)方式上,F(xiàn)lashAttention-3提出了一種兩階段GEMM-softmax流水線方案,以打破不同操作之間的數(shù)據(jù)依賴。
第一階段,當前迭代(iteration)的softmax操作與下一個迭代的Q·K^T矩陣乘法重疊執(zhí)行。
第二階段,當前迭代的P·V矩陣乘法與下一個迭代的softmax操作重疊執(zhí)行。
通過引入額外的寄存器和共享內(nèi)存緩沖區(qū),F(xiàn)lashAttention-3實現(xiàn)了跨迭代的數(shù)據(jù)傳遞和重用。
在每個迭代中,Q·K^T的結(jié)果首先存儲在名為S_cur的緩沖區(qū)中,用于當前迭代的softmax計算,同時異步執(zhí)行下一個迭代的Q·K^T矩陣乘法,結(jié)果存儲在名為S_next的緩沖區(qū)中。
在執(zhí)行當前迭代的P·V矩陣乘法時,異步執(zhí)行下一個迭代的softmax操作,并更新S_cur和S_next緩沖區(qū)。
第三項更新,是用更低的FP8精度替代FP16。
實際上,降低數(shù)值精度是一種常見的優(yōu)化策略,可以顯著提高GPU的計算吞吐量和能效,Hopper GPU也引入了FP8精度的Tensor Core支持。
但是,直接將注意力計算從FP16轉(zhuǎn)換為FP8可能會引入較大的精度損失。
另外,F(xiàn)P8 Tensor Core對輸入數(shù)據(jù)的布局也有特定的要求(K維度連續(xù)),不幸的是,注意力計算中的輸入數(shù)據(jù)存儲格式(頭維度連續(xù))并不符合這樣的要求。
所以FlashAttention-3首先引入了一系列內(nèi)存布局轉(zhuǎn)換技術(shù),動態(tài)轉(zhuǎn)置V矩陣的塊,改變其連續(xù)方式,從而適配FP8 Tensor Core的布局要求。
在此基礎之上,為了獲得更高的計算精度,F(xiàn)lashAttention-3又采用了分塊量化和非相干處理技術(shù)。
傳統(tǒng)的量化方法通常對整個矩陣使用一個統(tǒng)一的縮放因子(per-tensor quantization),無法很好地適應不同區(qū)域的數(shù)值范圍。
FlashAttention-3則采用了分塊量化(block-wise quantization)的策略,為每個塊單獨設置縮放因子,更好地捕捉局部的數(shù)值分布。
非相干處理(incoherent processing)技術(shù)則是通過隨機正交矩陣對輸入數(shù)據(jù)進行旋轉(zhuǎn),破壞不同塊之間的相干性,減少量化誤差的傳播。
這兩項技術(shù)的結(jié)合使得FlashAttention-3在FP8精度下取得了更高的計算精度,顯著優(yōu)于傳統(tǒng)的量化方法。
結(jié)果,與基于傳統(tǒng)量化方法的FP8實現(xiàn)相比,F(xiàn)lashAttention-3的使得精度提高了2.6倍。
比標準Attention快16倍
以上就是FlashAttention-3在充分研究Hopper架構(gòu)特點后做出的三大更新,針對更新后的表現(xiàn),作者主要進行了3方面測試。
注意力基準測試
消融實驗
FP8注意力準確性測試
首先來看注意力基準測試。
通過改變序列長度(512、1k、……16k),并設置批大小以確??倀oken數(shù)為16k。研究人員將隱藏維度設置為2048,頭維度設置為64、128或258,計算前向傳播、后向傳播。
對比標準Attention、FlashAttention-2、Triton、cuDNN和FlashAttention-3,在H100 80GB SXM5上FP16的運行時間。
FlashAttention-3的前向傳播比FlashAttention-2快1.5-2倍,后向傳播快1.5-1.75倍。
與標準Attention相比,F(xiàn)lashAttention-3的速度快了3-16倍。
對于中長序列(1k以上),F(xiàn)lashAttention-3甚至超過了專門為H100優(yōu)化的cuDNN。
在消融實驗中,通過對非因果FP16 FlashAttention-3進行了2階段WGMMA-softmax流水線和warp特殊化的消融研究,參數(shù)固定為{batch, seqlen, nheads, hdim} = {4, 8448, 16, 128}。
結(jié)果證實,F(xiàn)lashAttention-3改進帶來了顯著加速,從570提升到661。
另外,因為對FlashAttention的數(shù)值誤差感興趣,研究團隊還將FlashAttention-2、FlashAttention-3和標準Attention進行了比較。
為了模擬LLMs中的異常特征和激活,研究團隊生成了Q、K、V的條目,分布為:N(0,1)+N(0,100)?Bernoulli(0.001)
也就是說,每個條目都服從均值為0、標準差為1的正態(tài)分布,但對于0.1%的條目,增加了一個獨立的項,其標準差為10。然后測量均方根誤差(RMSE)。
結(jié)果顯示,在FP16中,由于中間結(jié)果(softmax)保留在FP32中,F(xiàn)lashAttention-2和FlashAttention-3的RMSE比標準Attention減少1.7倍。
FP8的標準Attention使用每個張量的縮放,matmul累加器在FP32中,中間softmax結(jié)果保留在FP16中。由于塊量化和非相干處理,F(xiàn)P8中的FlashAttention-3比這個基線更準確2.6倍。
最后,論文還表示目前工作專注于Hopper架構(gòu),后續(xù)將推廣到其他硬件。
除了英偉達為研究提供了技術(shù)支持外,Meta、Together AI和普林斯頓大學為研究提供了計算支持。
本文來源:量子位,原文標題:《H100利用率飆升至75%!英偉達親自下場FlashAttention三代升級,比標準注意力快16倍》
風險提示及免責條款
市場有風險,投資需謹慎。本文不構(gòu)成個人投資建議,也未考慮到個別用戶特殊的投資目標、財務狀況或需要。用戶應考慮本文中的任何意見、觀點或結(jié)論是否符合其特定狀況。據(jù)此投資,責任自負。