Commit 5d24c95
authored
[CUDA] Update Flash Attention Implementation and APIs (#26937)
## Summary
This PR updates the Flash Attention implementation in ONNX Runtime,
syncing with newer kernel sources in
https://github.com/Dao-AILab/flash-attention, and extending the internal
API to support additional features required for advanced caching
scenarios. It also aligns specific kernels with the official
implementation.
## Changes
- **Flash Attention Kernels**: Updated/Added Flash Attention forward
kernels and headers in
`onnxruntime/contrib_ops/cuda/bert/flash_attention/`.
- **API Extension**: Updated `mha_fwd` and `mha_fwd_kvcache` in
`flash_api.h` and `flash_api.cc` to accept two new optional parameters:
- `cache_batch_idx`: Indices to index into the KV cache (support for
non-contiguous batch indices).
- `leftpad_k`: Support for left-padding in the key sequence.
- **Alignment & Fixes**:
- **Cleanup**: Removed redundant `kInfinity` definition in
`flash_fwd_kernel.h`.
- **Includes**: Added missing
`<core/providers/cuda/shared_inc/cuda_call.h>` in
`flash_fwd_launch_template.h`.
- **Integration**: Updated `group_query_attention_impl.cu` to align with
the new `mha_fwd_kvcache` signature.
- **Build Configuration**: Adjusted `onnxruntime_providers_cpu.cmake` to
update the exclusion list for Flash Attention kernels in quick build
mode.
## Implementation Details
- The `run_mha_fwd` helper now checks if `cache_batch_idx` is provided
alongside `k_new` to determine if the split kernel should be forced.
- New parameters are propagated through the call stack to the underlying
Flash Attention kernels.1 parent 1ed8fd9 commit 5d24c95
File tree
64 files changed
+1100
-863
lines changed- cmake
- onnxruntime/contrib_ops/cuda/bert
- flash_attention
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
64 files changed
+1100
-863
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
25 | 25 | | |
26 | 26 | | |
27 | 27 | | |
28 | | - | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
29 | 33 | | |
30 | 34 | | |
31 | | - | |
32 | | - | |
33 | | - | |
| 35 | + | |
34 | 36 | | |
35 | 37 | | |
36 | 38 | | |
37 | | - | |
38 | | - | |
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
| |||
Lines changed: 18 additions & 19 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
| 4 | + | |
4 | 5 | | |
5 | 6 | | |
6 | | - | |
7 | | - | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
8 | 10 | | |
9 | 11 | | |
10 | 12 | | |
| |||
17 | 19 | | |
18 | 20 | | |
19 | 21 | | |
20 | | - | |
| 22 | + | |
21 | 23 | | |
22 | 24 | | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
29 | 32 | | |
30 | | - | |
| 33 | + | |
31 | 34 | | |
32 | 35 | | |
33 | 36 | | |
34 | 37 | | |
35 | | - | |
36 | | - | |
37 | | - | |
| 38 | + | |
38 | 39 | | |
39 | 40 | | |
40 | 41 | | |
41 | 42 | | |
42 | | - | |
43 | | - | |
44 | | - | |
45 | | - | |
| 43 | + | |
| 44 | + | |
46 | 45 | | |
47 | 46 | | |
48 | 47 | | |
49 | 48 | | |
50 | 49 | | |
51 | 50 | | |
| 51 | + | |
52 | 52 | | |
53 | 53 | | |
54 | 54 | | |
55 | 55 | | |
56 | 56 | | |
57 | 57 | | |
58 | | - | |
59 | | - | |
| 58 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
80 | 80 | | |
81 | 81 | | |
82 | 82 | | |
| 83 | + | |
83 | 84 | | |
84 | 85 | | |
85 | 86 | | |
86 | 87 | | |
87 | | - | |
88 | | - | |
89 | 88 | | |
90 | 89 | | |
91 | 90 | | |
| |||
131 | 130 | | |
132 | 131 | | |
133 | 132 | | |
134 | | - | |
| 133 | + | |
135 | 134 | | |
| 135 | + | |
136 | 136 | | |
137 | 137 | | |
138 | 138 | | |
139 | 139 | | |
140 | | - | |
| 140 | + | |
141 | 141 | | |
142 | | - | |
| 142 | + | |
| 143 | + | |
143 | 144 | | |
144 | 145 | | |
145 | 146 | | |
| |||
Lines changed: 39 additions & 69 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
46 | | - | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
47 | 49 | | |
48 | 50 | | |
49 | 51 | | |
| |||
147 | 149 | | |
148 | 150 | | |
149 | 151 | | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
150 | 155 | | |
151 | 156 | | |
152 | 157 | | |
| |||
173 | 178 | | |
174 | 179 | | |
175 | 180 | | |
176 | | - | |
177 | | - | |
178 | | - | |
179 | | - | |
180 | | - | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
181 | 188 | | |
182 | 189 | | |
183 | 190 | | |
| |||
258 | 265 | | |
259 | 266 | | |
260 | 267 | | |
261 | | - | |
262 | | - | |
263 | | - | |
264 | | - | |
265 | | - | |
266 | | - | |
267 | | - | |
268 | | - | |
269 | | - | |
270 | | - | |
271 | | - | |
272 | | - | |
273 | | - | |
274 | | - | |
275 | 268 | | |
276 | 269 | | |
277 | 270 | | |
| |||
294 | 287 | | |
295 | 288 | | |
296 | 289 | | |
297 | | - | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
298 | 293 | | |
299 | 294 | | |
300 | 295 | | |
| |||
322 | 317 | | |
323 | 318 | | |
324 | 319 | | |
325 | | - | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
326 | 324 | | |
327 | | - | |
328 | | - | |
329 | | - | |
330 | | - | |
331 | | - | |
332 | | - | |
333 | | - | |
334 | | - | |
335 | 325 | | |
336 | 326 | | |
337 | 327 | | |
338 | 328 | | |
339 | 329 | | |
340 | | - | |
341 | | - | |
342 | | - | |
343 | 330 | | |
344 | 331 | | |
345 | | - | |
346 | | - | |
347 | 332 | | |
348 | 333 | | |
349 | 334 | | |
| |||
408 | 393 | | |
409 | 394 | | |
410 | 395 | | |
411 | | - | |
412 | | - | |
413 | | - | |
414 | | - | |
415 | | - | |
416 | | - | |
417 | 396 | | |
418 | 397 | | |
419 | 398 | | |
| |||
440 | 419 | | |
441 | 420 | | |
442 | 421 | | |
443 | | - | |
444 | | - | |
445 | | - | |
446 | | - | |
447 | | - | |
448 | | - | |
449 | | - | |
450 | | - | |
451 | | - | |
452 | | - | |
453 | | - | |
454 | | - | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
455 | 436 | | |
456 | 437 | | |
457 | 438 | | |
| |||
501 | 482 | | |
502 | 483 | | |
503 | 484 | | |
504 | | - | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
505 | 489 | | |
506 | 490 | | |
507 | 491 | | |
| |||
544 | 528 | | |
545 | 529 | | |
546 | 530 | | |
547 | | - | |
548 | | - | |
549 | | - | |
550 | | - | |
551 | | - | |
552 | | - | |
553 | | - | |
554 | | - | |
555 | | - | |
556 | | - | |
557 | | - | |
558 | | - | |
559 | | - | |
560 | 531 | | |
561 | 532 | | |
562 | 533 | | |
| |||
581 | 552 | | |
582 | 553 | | |
583 | 554 | | |
584 | | - | |
585 | 555 | | |
586 | 556 | | |
587 | 557 | | |
| |||
600 | 570 | | |
601 | 571 | | |
602 | 572 | | |
603 | | - | |
| 573 | + | |
604 | 574 | | |
605 | 575 | | |
606 | 576 | | |
| |||
Lines changed: 17 additions & 13 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
64 | | - | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
65 | 67 | | |
66 | 68 | | |
67 | 69 | | |
| |||
91 | 93 | | |
92 | 94 | | |
93 | 95 | | |
94 | | - | |
95 | | - | |
96 | | - | |
97 | | - | |
98 | | - | |
99 | | - | |
100 | | - | |
101 | | - | |
102 | | - | |
103 | | - | |
104 | | - | |
105 | | - | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
106 | 110 | | |
107 | 111 | | |
108 | 112 | | |
| |||
Lines changed: 15 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
0 commit comments