DualKV:共用提示 FlashAttention 核心消除 RL 訓練中的 KV 重複計算,30B MoE 加速 3.82 倍
arXiv · 2026-05-18 · 論文 2605.15422
arXiv 論文 2605.15422「DualKV」提出一種針對強化學習(RL)訓練場景的 FlashAttention 核心最佳化。在 GRPO、DAPO 等方法中,同一個提示會生成多個回應序列,標準 FlashAttention 對這些序列重複計算提示的 KV 表示,DualKV 以自訂 CUDA 核心消除這個冗餘。
背景
GRPO(Group Relative Policy Optimization)等 RL 訓練方法的典型做法是:從同一個提示取樣 N 個回應,計算每個回應相對於其他回應的獎勵,並以此更新策略。在前向/反向傳播時,需要對 N×(提示長度 + 回應長度)個 token 計算注意力——其中提示部分被重複了 N 次。大 rollout 數和長上下文讓這個冗餘成為主要瓶頸。
核心改動
因果遮罩的數學性質提供了最佳化的基礎:在 decoder-only 模型中,提示的 KV 表示在所有從同一提示生成的序列中完全相同——因果遮罩讓提示 token 只能看到自身之前的 token,與回應內容無關。DualKV 利用這個不變性:
- 共用提示區域:只計算一次,由所有回應序列共享
- 各序列回應區域:每個回應獨立計算
實作機制分兩層:自訂 CUDA 前向/反向核心(Fused Kernels)在單次 kernel launch 中同時處理兩個 KV 區域;資料重新打包(Data Repacking)將輸入從 N×(P+R) token 轉換為 P+N×R token,讓節省效果延伸至整個模型。這個方法數學上與標準注意力完全等價,不是近似計算。
實際效果
| 配置 | 加速倍率 | MFU(前→後) |
|---|---|---|
| Qwen3-8B GRPO(32 rollout, 8K ctx) | 1.63–2.09× | 36% → 76% |
| DAPO | 2.47× | —→ 77% |
| 30B MoE 策略更新 | 3.82× | — |
| 30B MoE 端對端 | 3.38× | — |
影響範圍
DualKV 的適用範圍是所有使用「單一提示多回應採樣」模式的 RL 訓練方法——GRPO、DAPO、PPO 的某些變體。對大型模型(30B 以上)的效益更顯著,因為提示佔總序列長度的比例更高時,冗餘計算的節省空間也更大。MFU 從 36% 提升至 76% 代表同樣的硬體能做超過兩倍的有效計算。程式碼與論文同步發布,可直接整合進現有的 vLLM 或自訂 RL 訓練管線。
原始來源:arXiv:2605.15422