
Commit 39d8520

[CUDA] GQA CUDA Kernel Fusion and Performance Optimization (#26920)
## Summary

This PR significantly improves GroupQueryAttention (GQA) performance on CUDA by fusing multiple kernel launches, improving memory access patterns, and cleaning up sequence length semantics.

## Key Changes

### 1. Fused Kernels for Reduced Launch Overhead

| New Kernel | Operations Fused | Kernels Saved |
|------------|------------------|---------------|
| `UnpackQKVWithRoPEAndAppendKV` | Unpack packed QKV + RoPE Q/K + KV cache append | 4-5 |
| `ConcatNewToPastKVFused` | K append + V append (separate buffer mode) | 1 |
| `ConcatKVInPlaceFused` | K append + V append (shared buffer mode) | 1 |

### 2. New `RotaryDispatcher` Template (`rotary_common.cuh`)

Reusable RoPE implementation for fused kernels supporting:

- `float`, `half`, `BFloat16` element types
- `float2`, `float4` vector types
- Interleaved and half-split rotation modes

### 3. Sequence Length Semantics Cleanup

**Before:** Confusing `seqlens_k` / `seqlens_k_buff` with overloaded meanings.

**After:** Clear separation:

- `past_seq_lens` - offset where new tokens are appended
- `total_seq_lens` - total valid tokens after append
- `padded_seq_lens` - padded length for first prompt masking

### 4. FlashAttention Fast Decode Path

New optimized path for token generation (`sequence_length == 1`, shared buffer):

- Bypasses the `GetSequenceLengths` kernel
- Passes `past_seq_lens` directly to Flash Attention
- Controlled by the `ORT_DISABLE_FLASH_DECODE` env var

### 5. Integer Overflow Prevention

All KV cache index calculations use `int64_t` to handle large `batch * heads * seq * head_size` products.

### 6. BFloat16 Vectorization

Added a `float4` (8 elements) vectorized path for BFloat16 in `ConcatTensorToTensor` (see the illustrative sketch after the file list below).

## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `ORT_DISABLE_FLASH_DECODE` | `false` | Disable the fast decode optimization |
| `ORT_DISABLE_FUSED_KV` | `false` | Use unfused K/V append kernels |

## Test Changes

### Improved Test Coverage Strategy

Restructured `gqa_cuda_prompt_test_cases()` and `gqa_cuda_past_test_cases()` to explicitly iterate over kernel code path parameters:

```python
# NEW: Primary iteration over kernel code paths
for h in h_sizes_to_test:
    for packed in packed_opts:
        for rotary, rotary_interleaved in rotary_opts:
            for share_buffer in share_buffer_opts:
                # Secondary params (batch, seq, heads) rotate via modulo
                ...
```

| Mode | Before | After |
|------|--------|-------|
| Pipeline | 16 tests, 4/12 combos | 42 tests, 8/12 combos |
| Comprehensive | 81 tests, 4/12 combos | 178 tests, 12/12 combos |

### New Test Parameters

- Added `seqs = [(1, 1)]` for edge case testing
- Added `heads = [(3, 1)]` for non-standard GQA ratios
- Added `h_sizes = [40]` for non-power-of-2 head sizes (tests rotary skip logic)

### New Test Configurations

- `share_buffer` config option (tests both buffer modes)
- `has_position_ids` testing on CUDA
- Padding prompt parity test
- Fused vs. unfused kernel parity tests (`TestFusedKernelParity`)
- Decoding from an empty cache, test case `(1, 1)`

## Files Changed

**Core:**
- `group_query_attention_impl.cu` - Main implementation refactoring
- `attention_kv_cache.cu` - Fused append kernels
- `flash_api.cc` - Packed QKV stride handling

**New:**
- `rotary_common.cuh` - Reusable RoPE dispatcher

**Tests:**
- `test_gqa.py` - Extended test coverage
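To make item 6 (BFloat16 Vectorization) concrete, here is a minimal illustrative CUDA sketch of the `float4`-vectorized copy idea. It is not the actual `ConcatTensorToTensor` kernel (which also handles per-head source/destination offsets); the kernel name is hypothetical, and it assumes the element count is a multiple of 8 and 16-byte-aligned pointers.

```cuda
#include <cuda_bf16.h>

// Hypothetical example kernel (not the PR's ConcatTensorToTensor): copies n
// BFloat16 elements using 16-byte float4 loads/stores, i.e. 8 elements per
// vector access. Assumes n % 8 == 0 and 16-byte aligned src/dst.
__global__ void CopyBFloat16Vectorized(const __nv_bfloat16* src,
                                       __nv_bfloat16* dst,
                                       int64_t n) {
  const int64_t num_vec = n / 8;  // 8 bf16 values fit in one float4
  const float4* src4 = reinterpret_cast<const float4*>(src);
  float4* dst4 = reinterpret_cast<float4*>(dst);
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_vec; i += stride) {
    dst4[i] = src4[i];  // one 16-byte transaction moves 8 BFloat16 elements
  }
}
```

Since a `float4` load/store moves 16 bytes, each vector access covers 8 two-byte BFloat16 values, which is what makes the wider path worthwhile for this element type.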
## Performance

For decoding or a subsequent prompt, we still use the original flash attention kernel, so performance is almost the same as the baseline. Here we only show results for the first prompt.

Below are results of `benchmark_gqa.py` on an H200 GPU. Note that latency is measured on an ONNX model containing a single GQA node, so it includes extra overhead; the kernel-level speedup can be larger (see the profiling results below).

### prompt-sm90-Llama3-8B-b1-h32_8x128-float16

**Configuration**: `batch=1, prompt (past_seq=0), num_heads=32, kv_heads=8, head_size=128, dtype=float16, gpu=H200`

Dense means Q, K and V are separate inputs; Packed means Q, K and V are packed into one input.

| Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** |
| --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- |
| 1024 | 0.470 | 0.277 | **1.70x** | 0.468 | 0.320 | **1.46x** |
| 2048 | 1.001 | 0.517 | **1.94x** | 0.990 | 0.590 | **1.68x** |
| 4096 | 2.691 | 1.174 | **2.29x** | 1.504 | 1.242 | **1.21x** |
| 8192 | 7.780 | 2.292 | **3.39x** | 7.933 | 4.004 | **1.98x** |

### prompt-sm90-Llama3-8B-b1-h32_8x128-bfloat16

**Configuration**: `batch=1, prompt (past_seq=0), num_heads=32, kv_heads=8, head_size=128, dtype=bfloat16, gpu=H200`

| Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** |
| --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- |
| 1024 | 0.477 | 0.274 | **1.74x** | 0.486 | 0.332 | **1.46x** |
| 2048 | 1.078 | 0.500 | **2.16x** | 1.087 | 0.601 | **1.81x** |
| 4096 | 2.633 | 1.144 | **2.30x** | 3.017 | 1.282 | **2.35x** |
| 8192 | 7.933 | 2.712 | **2.93x** | 7.933 | 4.003 | **1.98x** |

## Profiling Comparison (Prompt Phase)

**Summary**: Switching from `flash_fwd_splitkv_kernel` to the standard `flash_fwd_kernel` for the prompt phase (SeqLen=2048) results in a **~3x reduction in attention kernel latency** and a **~2x improvement in total operator latency**.

### 1. Packed QKV

**Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32, kv_heads=8, head_size=128`

| Metric | Baseline | Treatment | Delta |
| :--- | :--- | :--- | :--- |
| **Total Latency** | **639.3 us** | **287.0 us** | **2.23x Speedup** |
| **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.10 us | `flash_fwd_kernel`<br>187.70 us | **3.08x Speedup** |
| **Helper Kernels** | `ConcatNewToPastKV`: 4.71 us | `UnpackQKVWithRoPEAndAppendKV`: 32.44 us<br>`GetSequenceLengths`: 1.63 us | *Fused ops added* |

> **Note**: The Treatment implementation introduces a fused `UnpackQKVWithRoPEAndAppendKV` kernel which performs necessary pre-processing. Despite this added cost (~29 us), the massive gain from using the efficient `flash_fwd_kernel` instead of `flash_fwd_splitkv_kernel` yields a significant net speedup.

### 2. Dense (Separated QKV)

**Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32, kv_heads=8, head_size=128`

| Metric | Baseline | Treatment | Delta |
| :--- | :--- | :--- | :--- |
| **Total Latency** | **0.6468 ms** | **0.3226 ms** | **2.00x Speedup** |
| **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.25 us | `flash_fwd_kernel`<br>184.29 us | **3.08x Speedup** |
| **Helper Kernels** | `ConcatNewToPastKV`: 4.68 us | `RotaryEmbeddingBSNH`: 48.94 us<br>`ConcatNewToPastKVFused`: 13.04 us<br>`GetSequenceLengths`: 1.52 us | *See below* |

> **Note**: Similar to the Packed case, the switch to the standard Flash Attention forward kernel drives the performance improvement.
> The pre-processing is handled by `RotaryEmbeddingBSNH` and `ConcatNewToPastKVFused` in the Treatment.
1 parent 5b88e4e commit 39d8520

16 files changed: +2746 −631 lines changed

onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -351,7 +351,7 @@ Status CheckCustomAttentionInputs(const T* position_ids,
 
     if (pos_ids_shape[1] < parameters.sequence_length) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                            "position_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]);
+                            "position_ids dimension 1 must be at least sequence length, got ", pos_ids_shape[1]);
    }
  }
 
```

onnxruntime/contrib_ops/cuda/bert/attention_data.h

Lines changed: 31 additions & 2 deletions

```diff
@@ -153,24 +153,51 @@ struct GroupQueryAttentionData {
   const T* value = nullptr;
   const T* past_key = nullptr;
   const T* past_value = nullptr;
-  int* seqlens_k = nullptr;
   const T* cos_cache = nullptr;
   const T* sin_cache = nullptr;
   const T* head_sink = nullptr;
 
+  // Total sequence length for each batch. It has shape [batch_size].
+  int* total_seq_lens = nullptr;
+
+  // Past sequence length for each batch (i.e., the offset to append new tokens). Shape [batch_size].
+  // For first prompt: past_seq_lens[b] = 0
+  // For token generation or subsequent prompt: past_seq_lens[b] = total_seq_lens[b] - sequence_length
+  int* past_seq_lens = nullptr;
+
+  // Padded sequence length for each batch. Shape [batch_size].
+  // Only used for first prompt: padded_seq_lens[b] = sequence_length
+  int* padded_seq_lens = nullptr;
+
   // Flash buffers
   T* softmax_lse = nullptr;
   T* softmax_lse_accum = nullptr;
   T* out_accum = nullptr;
-  int* seqlens_k_buff = nullptr;
+
+  // Position IDs from Input
+  const int64_t* position_ids = nullptr;
 
   // Memory Efficient buffers
   T* fmha_buffer = nullptr;
   T* unpacked_qkv_buffer = nullptr;
   T* rotary_buffer = nullptr;
+  int64_t* position_ids_buffer = nullptr;  // Separate buffer for generated position IDs
   T* k = nullptr;
   T* v = nullptr;
 
+#ifndef NDEBUG
+  // Buffer size tracking for debug validation
+  // Allocated sizes are set during buffer allocation in group_query_attention.cc
+  // Max used sizes are updated during kernel calls in group_query_attention_impl.cu
+  // Validated before operator returns to ensure usage exactly matches allocation
+  size_t unpacked_qkv_buffer_size = 0;       // Allocated size
+  size_t rotary_buffer_size = 0;             // Allocated size
+  size_t position_ids_buffer_size = 0;       // Allocated size
+  mutable size_t unpacked_qkv_max_used = 0;  // Max offset accessed (updated by kernels)
+  mutable size_t rotary_max_used = 0;        // Max offset accessed (updated by kernels)
+  mutable size_t position_ids_max_used = 0;  // Max offset accessed (updated by kernels)
+#endif
+
   // Output Tensors
   T* output = nullptr;
   T* present_key = nullptr;
@@ -179,6 +206,8 @@ struct GroupQueryAttentionData {
   // Kernel Flags
   bool use_flash_attention = false;
   bool use_memory_efficient_attention = false;
+  bool use_flash_attention_fast_decode = false;
+  bool disable_fused_kv = false;
 };
 
 template <typename T>
```
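To illustrate the relationship documented in the new comments above, here is a minimal sketch that derives `past_seq_lens` from `total_seq_lens` for the token-generation / subsequent-prompt case (for a first prompt the entries would simply be zero). This is not a kernel from this commit; the name and one-thread-per-batch launch are assumptions for illustration.

```cuda
// Hypothetical helper, for illustration only: fills past_seq_lens from
// total_seq_lens using the relationship documented in attention_data.h:
//   past_seq_lens[b] = total_seq_lens[b] - sequence_length
// (a first prompt would instead use past_seq_lens[b] = 0).
__global__ void DerivePastSeqLens(const int* total_seq_lens,
                                  int* past_seq_lens,
                                  int batch_size,
                                  int sequence_length) {
  const int b = blockIdx.x * blockDim.x + threadIdx.x;  // one thread per batch entry
  if (b < batch_size) {
    past_seq_lens[b] = total_seq_lens[b] - sequence_length;
  }
}
```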

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

Lines changed: 1 addition & 1 deletion

```diff
@@ -918,7 +918,7 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i
   constexpr bool is_new_kv_bnsh_format = true;
   ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
       batch_size, num_heads, qk_head_size, parameters.max_sequence_length,
-      data.seqlens_k_total, nullptr, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
+      nullptr, data.seqlens_k_total, parameters.sequence_length, data.k, data.v, data.present_key, data.present_value,
       is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, max_threads_per_block));
 
   data.k = data.present_key;
```
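The call above feeds the per-batch sequence-length buffers into `LaunchConcatKVInPlace`, which this PR complements with fused append kernels. Below is a minimal illustrative sketch of the two ideas from the description: fusing the K and V appends into a single launch and computing cache offsets in `int64_t` to prevent overflow. It is not the actual `ConcatKVInPlaceFused` / `ConcatNewToPastKVFused` implementation; the kernel name, `half`-only element type, and launch geometry are assumptions for illustration.

```cuda
#include <cuda_fp16.h>

// Illustrative sketch only (not the PR's fused append kernels): appends the
// new K and V tokens to the present caches in a single launch instead of two,
// using BNSH layout and int64_t offsets so that large
// batch * heads * max_seq_len * head_size products cannot overflow 32-bit int.
// Assumed launch geometry: grid = (new_seq_len, kv_num_heads, batch_size),
// block = (head_size).
__global__ void AppendKVFusedSketch(const half* new_k, const half* new_v,
                                    half* present_k, half* present_v,
                                    const int* past_seq_lens,
                                    int kv_num_heads, int new_seq_len,
                                    int max_seq_len, int head_size) {
  const int h = threadIdx.x;   // position within the head
  const int s = blockIdx.x;    // new token index
  const int n = blockIdx.y;    // kv head index
  const int b = blockIdx.z;    // batch index
  if (h >= head_size) return;

  // int64_t index arithmetic for both source and destination offsets.
  const int64_t src = (((static_cast<int64_t>(b) * kv_num_heads + n) * new_seq_len) + s) * head_size + h;
  const int64_t dst = (((static_cast<int64_t>(b) * kv_num_heads + n) * max_seq_len) +
                       past_seq_lens[b] + s) * head_size + h;
  present_k[dst] = new_k[src];  // K append
  present_v[dst] = new_v[src];  // V append (previously a separate kernel launch)
}
```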
