# [CUDA] GQA CUDA Kernel Fusion and Performance Optimization (#26920)
## Summary
This PR significantly improves GroupQueryAttention (GQA) performance on
CUDA by fusing multiple kernel launches, improving memory access
patterns, and cleaning up sequence length semantics.
## Key Changes
### 1. Fused Kernels for Reduced Launch Overhead
| New Kernel | Operations Fused | Kernels Saved |
|------------|------------------|---------------|
| `UnpackQKVWithRoPEAndAppendKV` | Unpack packed QKV + RoPE Q/K + KV cache append | 4-5 |
| `ConcatNewToPastKVFused` | K append + V append (separate buffer mode) | 1 |
| `ConcatKVInPlaceFused` | K append + V append (shared buffer mode) | 1 |
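To make the fusion concrete, here is a minimal sketch of a single-launch K+V append, assuming a BNSH cache layout; the kernel name, signature, and layout are illustrative and not the PR's actual `ConcatNewToPastKVFused` code.

```cpp
// Sketch: one launch appends both new K and new V into their caches.
// Assumed layout (illustrative): caches are [batch, kv_heads, max_seq, head_size],
// new tokens are [batch, kv_heads, new_seq, head_size].
#include <cstdint>

template <typename T>
__global__ void FusedAppendKV(const T* new_k, const T* new_v, T* k_cache, T* v_cache,
                              const int* past_seq_lens,  // per-batch append offset
                              int kv_heads, int new_seq, int max_seq, int head_size) {
  const int b = blockIdx.z >> 1;     // batch index
  const bool is_v = blockIdx.z & 1;  // even planes handle K, odd planes handle V
  const int h = blockIdx.y;          // kv head
  const int s = blockIdx.x;          // new-token index

  const T* src = is_v ? new_v : new_k;
  T* dst = is_v ? v_cache : k_cache;

  // int64_t indexing so batch * heads * seq * head_size cannot overflow 32 bits.
  const int64_t src_base = (((int64_t)b * kv_heads + h) * new_seq + s) * head_size;
  const int64_t dst_base = (((int64_t)b * kv_heads + h) * max_seq + past_seq_lens[b] + s) * head_size;

  for (int d = threadIdx.x; d < head_size; d += blockDim.x) {
    dst[dst_base + d] = src[src_base + d];
  }
}
```

A single `dim3 grid(new_seq, kv_heads, 2 * batch)` launch then covers both tensors, which is where the saved launch in the table above comes from.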
### 2. New `RotaryDispatcher` Template (`rotary_common.cuh`)
Reusable RoPE implementation for fused kernels (sketched after this list), supporting:
- `float`, `half`, `BFloat16` element types
- `float2`, `float4` vector types
- Interleaved and half-split rotation modes
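For orientation, a scalar sketch of the interleaved rotation such a helper applies; the function name and signature are illustrative assumptions, not the `RotaryDispatcher` interface itself.

```cpp
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Illustrative only: rotate one head vector in interleaved mode, pairing
// elements (2i, 2i+1) with precomputed cos/sin for this token position.
// Half-split mode would pair (i, i + rotary_dim / 2) instead.
template <typename T>
__device__ void ApplyRopeInterleaved(T* head, const float* cos_cache,
                                     const float* sin_cache, int rotary_dim) {
  for (int i = 0; i < rotary_dim / 2; ++i) {
    const float c = cos_cache[i];
    const float s = sin_cache[i];
    const float x0 = static_cast<float>(head[2 * i]);
    const float x1 = static_cast<float>(head[2 * i + 1]);
    head[2 * i]     = static_cast<T>(x0 * c - x1 * s);
    head[2 * i + 1] = static_cast<T>(x0 * s + x1 * c);
  }
}
```

The `float2`/`float4` support listed above is a refinement of this scalar form: each thread loads and rotates several element pairs per memory transaction.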
### 3. Sequence Length Semantics Cleanup
**Before:** Confusing `seqlens_k` / `seqlens_k_buff` with overloaded
meanings.
**After:** Clear separation (illustrated below):
- `past_seq_lens` - offset where new tokens are appended
- `total_seq_lens` - total valid tokens after append
- `padded_seq_lens` - padded length for first prompt masking
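As an illustration of the renamed quantities (the struct is hypothetical; only the field names come from the PR):

```cpp
// Hypothetical per-sequence view of the cleaned-up semantics.
struct SequenceLengths {
  int past_seq_len;    // tokens already in the KV cache; new tokens append at this offset
  int new_seq_len;     // tokens provided in the current call
  int total_seq_len;   // valid tokens after the append: past_seq_len + new_seq_len
  int padded_seq_len;  // padded length used only for first-prompt masking
};
```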
### 4. FlashAttention Fast Decode Path
New optimized path for token generation (`sequence_length == 1`, shared KV buffer), sketched below:
- Bypasses `GetSequenceLengths` kernel
- Passes `past_seq_lens` directly to Flash Attention
- Controlled by `ORT_DISABLE_FLASH_DECODE` env var
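A sketch of the gating condition described above; the helper name is hypothetical, and the environment check is shown with plain `std::getenv` (the exact accepted values and ONNX Runtime's env-var plumbing may differ).

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical predicate mirroring the fast-decode conditions described above.
bool UseFlashDecodePath(int sequence_length, bool past_present_share_buffer) {
  const char* flag = std::getenv("ORT_DISABLE_FLASH_DECODE");
  const bool disabled = flag != nullptr && std::strcmp(flag, "1") == 0;  // accepted values assumed

  // Token generation: one new token per step with an in-place KV cache, so
  // past_seq_lens can be handed to Flash Attention directly and the
  // GetSequenceLengths kernel is skipped.
  return !disabled && sequence_length == 1 && past_present_share_buffer;
}
```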
### 5. Integer Overflow Prevention
All KV cache index calculations use `int64_t` to handle large `batch *
heads * seq * head_size` products.
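For example (an illustrative shape, not one from the PR), the element count of a large KV cache can exceed `INT32_MAX`, so the promotion to `int64_t` has to happen before the multiplication:

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative KV-cache shape: 16 x 8 x 131072 x 128 = 2,147,483,648 elements.
  const int batch = 16, kv_heads = 8, max_seq = 128 * 1024, head_size = 128;

  // Promote before multiplying; `batch * kv_heads * max_seq * head_size` in plain
  // int arithmetic would overflow (undefined behavior) for this shape.
  const int64_t elements = (int64_t)batch * kv_heads * max_seq * head_size;

  std::printf("elements = %lld, fits in int32: %s\n",
              (long long)elements, elements <= INT_MAX ? "yes" : "no");
  return 0;
}
```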
### 6. BFloat16 Vectorization
Added `float4` (8 elements) vectorized path for BFloat16 in
`ConcatTensorToTensor`.
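A standalone sketch of the vectorization idea (not the actual `ConcatTensorToTensor` code): reinterpret eight BFloat16 values as one `float4` so each thread moves 16 bytes per load/store.

```cpp
#include <cuda_bf16.h>
#include <cstdint>

// Copy n BFloat16 elements 8 at a time via 16-byte float4 loads/stores.
// Assumes n is a multiple of 8 and both pointers are 16-byte aligned;
// a real kernel would also handle the scalar tail.
__global__ void CopyBf16Vectorized(const __nv_bfloat16* src, __nv_bfloat16* dst, int64_t n) {
  const int64_t n_vec = n / 8;  // 8 bf16 = 16 bytes = one float4
  const float4* src4 = reinterpret_cast<const float4*>(src);
  float4* dst4 = reinterpret_cast<float4*>(dst);

  for (int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x; i < n_vec;
       i += (int64_t)gridDim.x * blockDim.x) {
    dst4[i] = src4[i];  // one 16-byte transaction instead of eight 2-byte ones
  }
}
```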
## Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `ORT_DISABLE_FLASH_DECODE` | `false` | Disable fast decode optimization |
| `ORT_DISABLE_FUSED_KV` | `false` | Use unfused K/V append kernels |
## Test Changes
### Improved Test Coverage Strategy
Restructured `gqa_cuda_prompt_test_cases()` and
`gqa_cuda_past_test_cases()` to explicitly iterate over kernel code path
parameters:
```python
# NEW: Primary iteration over kernel code paths
for h in h_sizes_to_test:
for packed in packed_opts:
for rotary, rotary_interleaved in rotary_opts:
for share_buffer in share_buffer_opts:
# Secondary params (batch, seq, heads) rotate via modulo
```
| Mode | Before | After |
|------|--------|-------|
| Pipeline | 16 tests, 4/12 combos | 42 tests, 8/12 combos |
| Comprehensive | 81 tests, 4/12 combos | 178 tests, 12/12 combos |
### New Test Parameters
- Added `seqs = [(1, 1)]` for edge case testing
- Added `heads = [(3, 1)]` for non-standard GQA ratios
- Added `h_sizes = [40]` for non-power-of-2 head sizes (tests rotary
skip logic)
### New Test Configurations
- `share_buffer` config option (tests both buffer modes)
- `has_position_ids` testing on CUDA
- Padding prompt parity test
- Fused vs unfused kernel parity tests (`TestFusedKernelParity`)
- Decoding from empty cache test case `(1, 1)`
## Files Changed
**Core:**
- `group_query_attention_impl.cu` - Main implementation refactoring
- `attention_kv_cache.cu` - Fused append kernels
- `flash_api.cc` - Packed QKV stride handling
**New:**
- `rotary_common.cuh` - Reusable RoPE dispatcher
**Tests:**
- `test_gqa.py` - Extended test coverage
## Performance
For decoding and subsequent prompts we still use the original Flash Attention kernel, so performance is essentially the same as the baseline; only first-prompt results are shown here.
Below are results from `benchmark_gqa.py` on an H200 GPU. Note that latency is measured on an ONNX model containing a single GQA node, so it includes overhead outside the kernel; the kernel-level speedup can be larger (see the profiling results below).
### prompt-sm90-Llama3-8B-b1-h32_8x128-float16
**Configuration**: `batch=1, prompt (past_seq=0), num_heads=32,
kv_heads=8, head_size=128, dtype=float16, gpu=H200`
Dense means Q, K, and V are separate inputs; Packed means Q, K, and V are packed into one input.
| Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** |
| --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- |
| 1024 | 0.470 | 0.277 | **1.70x** | 0.468 | 0.320 | **1.46x** |
| 2048 | 1.001 | 0.517 | **1.94x** | 0.990 | 0.590 | **1.68x** |
| 4096 | 2.691 | 1.174 | **2.29x** | 1.504 | 1.242 | **1.21x** |
| 8192 | 7.780 | 2.292 | **3.39x** | 7.933 | 4.004 | **1.98x** |
### prompt-sm90-Llama3-8B-b1-h32_8x128-bfloat16
**Configuration**: `batch=1, prompt (past_seq=0), num_heads=32,
kv_heads=8, head_size=128, dtype=bfloat16, gpu=H200`
| Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** |
| --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- |
| 1024 | 0.477 | 0.274 | **1.74x** | 0.486 | 0.332 | **1.46x** |
| 2048 | 1.078 | 0.500 | **2.16x** | 1.087 | 0.601 | **1.81x** |
| 4096 | 2.633 | 1.144 | **2.30x** | 3.017 | 1.282 | **2.35x** |
| 8192 | 7.933 | 2.712 | **2.93x** | 7.933 | 4.003 | **1.98x** |
## Profiling Comparison (Prompt Phase)
**Summary**:
Switching from `flash_fwd_splitkv_kernel` to the standard `flash_fwd_kernel` for the prompt phase (SeqLen=2048) results in a **~3x reduction in attention kernel latency** and a **~2x improvement in total operator latency**.
### 1. Packed QKV
**Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32,
kv_heads=8, head_size=128`
| Metric | Baseline | Treatment | Delta |
| :--- | :--- | :--- | :--- |
| **Total Latency** | **639.3 us** | **287.0 us** | **2.23x Speedup** |
| **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.10 us | `flash_fwd_kernel`<br>187.70 us | **3.08x Speedup** |
| **Helper Kernels** | `ConcatNewToPastKV`: 4.71 us | `UnpackQKVWithRoPEAndAppendKV`: 32.44 us<br>`GetSequenceLengths`: 1.63 us | *Fused ops added* |
> **Note**: The Treatment implementation introduces a fused
`UnpackQKVWithRoPEAndAppendKV` kernel which performs necessary
pre-processing. Despite this added cost (~29 us), the massive gain from
using the efficient `flash_fwd_kernel` instead of
`flash_fwd_splitkv_kernel` yields a significant net speedup.
### 2. Dense (Separated QKV)
**Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32,
kv_heads=8, head_size=128`
| Metric | Baseline | Treatment | Delta |
| :--- | :--- | :--- | :--- |
| **Total Latency** | **0.6468 ms** | **0.3226 ms** | **2.00x Speedup** |
| **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.25 us | `flash_fwd_kernel`<br>184.29 us | **3.08x Speedup** |
| **Helper Kernels** | `ConcatNewToPastKV`: 4.68 us | `RotaryEmbeddingBSNH`: 48.94 us<br>`ConcatNewToPastKVFused`: 13.04 us<br>`GetSequenceLengths`: 1.52 us | *See below* |
> **Note**: Similar to the Packed case, the switch to the standard Flash
Attention forward kernel drives the performance improvement. The
pre-processing is handled by `RotaryEmbeddingBSNH` and
`ConcatNewToPastKVFused` in the treatment.