Commit 5d24c95

[CUDA] Update Flash Attention Implementation and APIs (#26937)
## Summary

This PR updates the Flash Attention implementation in ONNX Runtime, syncing with newer kernel sources from https://github.com/Dao-AILab/flash-attention, and extends the internal API to support additional features required for advanced caching scenarios. It also aligns specific kernels with the official implementation.

## Changes

- **Flash Attention Kernels**: Updated/added Flash Attention forward kernels and headers in `onnxruntime/contrib_ops/cuda/bert/flash_attention/`.
- **API Extension**: Updated `mha_fwd` and `mha_fwd_kvcache` in `flash_api.h` and `flash_api.cc` to accept two new optional parameters (illustrated in the sketch after this list):
  - `cache_batch_idx`: Indices to index into the KV cache (support for non-contiguous batch indices).
  - `leftpad_k`: Support for left-padding in the key sequence.
- **Alignment & Fixes**:
  - **Cleanup**: Removed the redundant `kInfinity` definition in `flash_fwd_kernel.h`.
  - **Includes**: Added the missing `<core/providers/cuda/shared_inc/cuda_call.h>` include in `flash_fwd_launch_template.h`.
  - **Integration**: Updated `group_query_attention_impl.cu` to align with the new `mha_fwd_kvcache` signature.
- **Build Configuration**: Adjusted `onnxruntime_providers_cpu.cmake` to update the exclusion list for Flash Attention kernels in quick build mode.

## Implementation Details

- `mha_fwd_kvcache` now forces the split kernel when `cache_batch_idx` is provided, in addition to the existing `k_new` and packed-QKV cases.
- The new parameters are propagated through the call stack to the underlying Flash Attention kernels.
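For context, here is a minimal host-side sketch (hypothetical values, not ONNX Runtime code) of the per-batch metadata the two new optional parameters carry, based on the bullets above and the argument comments in `flash_api.h`:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical metadata for a batch of 4 requests.
  // cache_batch_idx: request b reads/writes KV-cache slot cache_batch_idx[b],
  //                  so cache slots need not be contiguous or in batch order.
  // leftpad_k:       number of padding tokens at the start of each K sequence
  //                  that the kernels should skip.
  std::vector<int> cache_batch_idx = {3, 0, 7, 1};
  std::vector<int> leftpad_k = {0, 2, 0, 5};

  for (size_t b = 0; b < cache_batch_idx.size(); ++b) {
    std::printf("batch %zu -> cache slot %d, skip %d left-padded K tokens\n",
                b, cache_batch_idx[b], leftpad_k[b]);
  }
  return 0;
}
```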
1 parent 1ed8fd9 commit 5d24c95


64 files changed: +1100 additions, -863 deletions

Some content is hidden: large commits have some content hidden by default.

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 6 additions & 6 deletions

@@ -25,17 +25,17 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh"
 )
 
-# Quick build mode: Filter out non-hdim128 flash attention kernels for faster development iteration
+# Quick build mode: Filter flash attention kernels for faster development iteration.
+# - We keep only hdim128 fp16 flash attention kernels in quick build mode.
+# - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256).
+# - This regex matches both `flash_fwd_hdim*` and `flash_fwd_split_hdim*` kernels.
+# If new head dimensions are added or removed, update this list to match the supported set.
 if(onnxruntime_QUICK_BUILD)
   message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
-  # Filter non-hdim128 kernels
-  list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|160|192|224|256)")
-  # Filter all bfloat16 kernels (only keep fp16)
+  list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
   list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16")
 endif()
 
-
-
 file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h"
   "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc"
Lines changed: 18 additions & 19 deletions

@@ -1,10 +1,12 @@
 /******************************************************************************
  * Copyright (c) 2023, Tri Dao.
  ******************************************************************************/
+
 #pragma once
 
-namespace onnxruntime {
-namespace flash {
+#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h"
+namespace FLASH_NAMESPACE {
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <bool Varlen = true>
@@ -17,43 +19,40 @@ struct BlockInfo {
                     : params.cu_seqlens_k[bidb]),
         actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr
                             ? params.seqlen_q
-                            : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+                            : params.cu_seqlens_q[bidb + 1] - sum_s_q),
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        ,
-        seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr
-                           ? params.seqlen_k
-                           : (params.is_seqlens_k_cumulative
-                                  ? params.cu_seqlens_k[bidb + 1] - sum_s_k
-                                  : params.cu_seqlens_k[bidb])),
+        leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]),
+        seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr
+                            ? params.seqlen_k
+                            : (params.is_seqlens_k_cumulative
+                                   ? params.cu_seqlens_k[bidb + 1] - sum_s_k
+                                   : params.cu_seqlens_k[bidb])) -
+                       leftpad_k),
         actual_seqlen_k(params.seqused_k
-                            ? params.seqused_k[bidb]
+                            ? params.seqused_k[bidb] - leftpad_k
                             : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
   }
 
   template <typename index_t>
-  __forceinline__ __device__
-      index_t
-      q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+  __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
     return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
   }
 
   template <typename index_t>
-  __forceinline__ __device__
-      index_t
-      k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-    return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+  __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
   }
 
   const int sum_s_q;
   const int sum_s_k;
   const int actual_seqlen_q;
   // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+  const int leftpad_k;
   const int seqlen_k_cache;
   const int actual_seqlen_k;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-}  // namespace flash
-}  // namespace onnxruntime
+}  // namespace FLASH_NAMESPACE
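To make the adjusted offset arithmetic concrete, here is a small host-side sketch of the non-varlen case (`sum_s_k == -1`) of the `k_offset` change above; the layout and values are made up for illustration and only mirror the device code:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the non-varlen branch of k_offset above: the per-batch base offset is
// advanced by leftpad_k rows so reads of K start after the left padding.
int64_t k_offset_nonvarlen(int64_t batch_stride, int64_t row_stride, int bidb, int leftpad_k) {
  return bidb * batch_stride + static_cast<int64_t>(leftpad_k) * row_stride;
}

int main() {
  // Made-up BSNH layout: seqlen_k_max = 8, num_heads_k = 2, head_size = 4.
  const int64_t row_stride = 2 * 4;        // elements between consecutive K tokens
  const int64_t batch_stride = 8 * 2 * 4;  // elements between consecutive batch entries
  // Batch entry 1 has 3 left-padding tokens: its seqlen_k_cache shrinks by 3 and
  // its K reads start 3 rows later.
  std::printf("k offset for bidb=1, leftpad_k=3: %lld elements\n",
              static_cast<long long>(k_offset_nonvarlen(batch_stride, row_stride, 1, 3)));
  return 0;
}
```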

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h

Lines changed: 6 additions & 5 deletions

@@ -80,12 +80,11 @@ struct Flash_fwd_params : public Qkv_params {
   // array of length b+1 holding starting offset of each sequence.
   int* __restrict__ cu_seqlens_q = nullptr;
   int* __restrict__ cu_seqlens_k = nullptr;
+  int* __restrict__ leftpad_k = nullptr;
 
   // If provided, the actual length of each k sequence.
   int* __restrict__ seqused_k = nullptr;
 
-  int* __restrict__ blockmask = nullptr;
-
   // The K_new and V_new matrices.
   void* __restrict__ knew_ptr = nullptr;
   void* __restrict__ vnew_ptr = nullptr;

@@ -131,15 +130,17 @@ struct Flash_fwd_params : public Qkv_params {
   void* __restrict__ alibi_slopes_ptr = nullptr;
   index_t alibi_slopes_batch_stride = 0;
 
-  bool unpadded_lse = false;
+  bool unpadded_lse = false;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
   const cudaDeviceProp* dprops = nullptr;
+  bool seqlenq_ngroups_swapped = false;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <typename T, int Headdim>
+template <typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
-template <typename T, int Headdim>
+
+template <typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream);
 
 }  // namespace flash

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc

Lines changed: 39 additions & 69 deletions

@@ -43,7 +43,9 @@ void set_params_fprop(Flash_fwd_params& params,
                       bool kv_bsnh = true,
                       int window_size_left = -1,
                       int window_size_right = -1,
-                      const bool unpadded_lse = false) {
+                      const bool unpadded_lse = false,
+                      void* cache_batch_idx = nullptr,
+                      void* leftpad_k = nullptr) {
   // Set the pointers and strides.
   params.q_ptr = q;
   params.k_ptr = k;

@@ -147,6 +149,9 @@ void set_params_fprop(Flash_fwd_params& params,
 
   params.is_seqlens_k_cumulative = true;
   params.unpadded_lse = unpadded_lse;
+
+  params.leftpad_k = static_cast<int*>(leftpad_k);
+  params.cache_batch_idx = static_cast<int*>(cache_batch_idx);
 }
 
 size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) {

@@ -173,11 +178,13 @@ size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads
 void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) {
   FP16_SWITCH(!params.is_bf16, [&] {
     HEADDIM_SWITCH(params.d, [&] {
-      if (params.num_splits <= 1 && !force_split_kernel) {
-        run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-      } else {
-        run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
-      }
+      BOOL_SWITCH(params.is_causal, Is_causal_const, [&] {
+        if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
+          run_mha_fwd_<elem_type, kHeadDim, Is_causal_const>(params, stream);
+        } else {
+          run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal_const>(params, stream);
+        }
+      });
     });
   });
 }

@@ -258,20 +265,6 @@ std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_
   }
 }
 
-// void set_params_alibi(Flash_fwd_params &params, void* alibi_slopes, int batch_size, int num_heads){
-//   if (alibi_slopes != nullptr) {
-//     // TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-//     // CHECK_DEVICE(alibi_slopes);
-//     // TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-//     // TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads})
-//     || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-//     params.alibi_slopes_ptr = alibi_slopes;
-//     params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? num_heads : 0;  // TODO: flag for bool
-//   } else {
-//     params.alibi_slopes_ptr = nullptr;
-//   }
-// }
-
 Status mha_fwd(const cudaDeviceProp& dprops,
                cudaStream_t stream,
                void* q,  // batch_size x seqlen_q x num_heads x head_size

@@ -294,7 +287,9 @@ Status mha_fwd(const cudaDeviceProp& dprops,
                void* softmax_lse_accum,  // num_splits x batch_size x seqlen_q x num_heads
                void* out_accum,          // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
                bool kv_bsnh,
-               int local_window_size) {
+               int local_window_size,
+               void* cache_batch_idx,
+               void* leftpad_k) {
   auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
   const int head_size_rounded = round_multiple(head_size, 32);
   const int seqlen_q_rounded = round_multiple(seqlen_q, 128);

@@ -322,28 +317,18 @@ Status mha_fwd(const cudaDeviceProp& dprops,
                    use_smooth_softmax,
                    kv_bsnh,
                    local_window_size,
-                   is_causal ? 0 : -1);
+                   is_causal ? 0 : -1,
+                   /*unpadded_lse=*/false,
+                   cache_batch_idx,
+                   leftpad_k);
   params.dprops = &dprops;
-  params.knew_ptr = nullptr;
-  params.vnew_ptr = nullptr;
-  params.knew_batch_stride = 0;
-  params.vnew_batch_stride = 0;
-  params.knew_row_stride = 0;
-  params.vnew_row_stride = 0;
-  params.knew_head_stride = 0;
-  params.vnew_head_stride = 0;
 
   params.num_splits = num_splits;
   if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
     params.softmax_lseaccum_ptr = softmax_lse_accum;
     params.oaccum_ptr = out_accum;
-  } else {
-    params.softmax_lseaccum_ptr = nullptr;
-    params.oaccum_ptr = nullptr;
   }
 
-  params.alibi_slopes_ptr = nullptr;
-
   run_mha_fwd(params, stream);
   return Status::OK();
 }

@@ -408,12 +393,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
 
   params.total_q = total_q;
   params.dprops = &dprops;
-  params.num_splits = 0;
-  params.softmax_lseaccum_ptr = nullptr;
-  params.oaccum_ptr = nullptr;
-  params.knew_ptr = nullptr;
-  params.vnew_ptr = nullptr;
-  params.alibi_slopes_ptr = nullptr;
   if (paged_KV) {
     params.block_table = block_table;
     params.block_table_batch_stride = max_num_blocks_per_seq;

@@ -440,18 +419,20 @@ bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_hea
 // of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_.
 Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        cudaStream_t stream,
-                       void* q,            // batch_size x seqlen_q x num_heads x head_size
-                       void* kcache,       // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
-                       void* vcache,       // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
-                       void* k_new,        // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
-                       void* v_new,        // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
-                       void* out,          // batch_size x seqlen_q x num_heads x head_size
-                       void* softmax_lse,  // batch_size x num_heads x seqlen_q
-                       void* seqlens_k_,   // batch_size
-                       void* rotary_cos,   // seqlen_ro x (rotary_dim / 2)
-                       void* rotary_sin,   // seqlen_ro x (rotary_dim / 2)
-                       void* head_sink,    // num_heads
-                       int* block_table,   // batch_size x max_num_blocks_per_seq
+                       void* q,                // batch_size x seqlen_q x num_heads x head_size
+                       void* kcache,           // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
+                       void* vcache,           // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
+                       void* k_new,            // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* v_new,            // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* out,              // batch_size x seqlen_q x num_heads x head_size
+                       void* softmax_lse,      // batch_size x num_heads x seqlen_q
+                       void* seqlens_k_,       // batch_size
+                       void* rotary_cos,       // seqlen_ro x (rotary_dim / 2)
+                       void* rotary_sin,       // seqlen_ro x (rotary_dim / 2)
+                       void* cache_batch_idx,  // (optional) indices to index into the KV cache
+                       void* leftpad_k,        // (optional) batch_size
+                       void* head_sink,        // num_heads
+                       int* block_table,       // batch_size x max_num_blocks_per_seq
                        int batch_size,
                        int num_heads,
                        int num_heads_k,

@@ -501,7 +482,10 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                    use_smooth_softmax,
                    past_bsnh,
                    local_window_size,
-                   is_causal ? 0 : -1);
+                   is_causal ? 0 : -1,
+                   /*unpadded_lse=*/false,
+                   cache_batch_idx,
+                   leftpad_k);
   params.dprops = &dprops;
 
   if (k_new != nullptr && v_new != nullptr) {

@@ -544,19 +528,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
     params.q_head_stride = head_size;
     params.knew_head_stride = head_size;
     params.vnew_head_stride = head_size;
-
-    params.knew_ptr = nullptr;
-    params.vnew_ptr = nullptr;
-  } else {
-    params.seqlen_knew = 0;
-    params.knew_ptr = nullptr;
-    params.vnew_ptr = nullptr;
-    params.knew_batch_stride = 0;
-    params.vnew_batch_stride = 0;
-    params.knew_row_stride = 0;
-    params.vnew_row_stride = 0;
-    params.knew_head_stride = 0;
-    params.vnew_head_stride = 0;
   }
 
   params.is_seqlens_k_cumulative = seqlens_k_ == nullptr;

@@ -581,7 +552,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
     params.oaccum_ptr = nullptr;
   }
 
-  params.alibi_slopes_ptr = nullptr;
   if (paged_KV) {
     params.block_table = block_table;
     params.block_table_batch_stride = max_num_blocks_per_seq;

@@ -600,7 +570,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
   // or if using packed QKV (to ensure correct handling of strided inputs which might be better supported or isolated in split kernel logic).
   // Note: if the fused kernel handles packing/rotary/appending, it should pass is_packed_qkv=false to this API (via use_packed_for_fa=false),
   // effectively bypassing this check and allowing standard kernels if otherwise eligible.
-  bool force_split = (k_new != nullptr) || is_packed_qkv;
+  bool force_split = (k_new != nullptr) || is_packed_qkv || cache_batch_idx != nullptr;
 
   run_mha_fwd(params, stream, force_split);
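The `BOOL_SWITCH(params.is_causal, Is_causal_const, ...)` added above follows the usual flash-attention "static switch" idea: a runtime flag picks a branch in which the same value is available as a compile-time constant, so `Is_causal` can be passed as a template argument of the kernel launcher. A minimal standalone sketch of that pattern (not the actual macro definition used in the repository):

```cpp
#include <cstdio>
#include <type_traits>

// Sketch of the static-switch idea: turn a runtime bool into a
// compile-time constant visible inside the lambda body.
template <typename F>
void bool_switch(bool cond, F&& f) {
  if (cond) {
    f(std::integral_constant<bool, true>{});
  } else {
    f(std::integral_constant<bool, false>{});
  }
}

// Stand-in for run_mha_fwd_<elem_type, kHeadDim, Is_causal_const>.
template <bool Is_causal>
void launch_sketch() {
  std::printf("launch kernel instantiation with Is_causal = %d\n", int(Is_causal));
}

int main() {
  bool is_causal = true;  // e.g. params.is_causal at runtime
  bool_switch(is_causal, [](auto is_causal_const) {
    launch_sketch<decltype(is_causal_const)::value>();
  });
  return 0;
}
```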
606576

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h

Lines changed: 17 additions & 13 deletions

@@ -61,7 +61,9 @@ Status mha_fwd(const cudaDeviceProp& dprops,
                void* softmax_lse_accum = nullptr,  // num_splits x batch_size x seqlen_q x num_heads
                void* out_accum = nullptr,          // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
                bool kv_bsnh = true,
-               int local_window_size = -1);
+               int local_window_size = -1,
+               void* cache_batch_idx = nullptr,
+               void* leftpad_k = nullptr);
 
 Status mha_varlen_fwd(const cudaDeviceProp& dprops,
                       cudaStream_t stream,

@@ -91,18 +93,20 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
 
 Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        cudaStream_t stream,
-                       void* q,            // batch_size x seqlen_q x num_heads x head_size
-                       void* kcache,       // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
-                       void* vcache,       // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
-                       void* k,            // batch_size x seqlen_k_new x num_heads_k x head_size
-                       void* v,            // batch_size x seqlen_k_new x num_heads_k x head_size
-                       void* out,          // batch_size x seqlen_q x num_heads x head_size
-                       void* softmax_lse,  // batch_size x num_heads x seqlen_q
-                       void* seqlens_k_,   // batch_size
-                       void* rotary_cos,   // seqlen_ro x (rotary_dim / 2)
-                       void* rotary_sin,   // seqlen_ro x (rotary_dim / 2)
-                       void* head_sink,    // num_heads
-                       int* block_table,   // batch_size x max_num_blocks_per_seq
+                       void* q,                // batch_size x seqlen_q x num_heads x head_size
+                       void* kcache,           // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                       void* vcache,           // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                       void* k,                // batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* v,                // batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* out,              // batch_size x seqlen_q x num_heads x head_size
+                       void* softmax_lse,      // batch_size x num_heads x seqlen_q
+                       void* seqlens_k_,       // batch_size
+                       void* rotary_cos,       // seqlen_ro x (rotary_dim / 2)
+                       void* rotary_sin,       // seqlen_ro x (rotary_dim / 2)
+                       void* cache_batch_idx,  // (optional) indices to index into the KV cache
+                       void* leftpad_k,        // (optional) batch_size
+                       void* head_sink,        // num_heads
+                       int* block_table,       // batch_size x max_num_blocks_per_seq
                        int batch_size,
                        int num_heads,
                        int num_heads_k,
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace FLASH_NAMESPACE {
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}
+
+}  // namespace FLASH_NAMESPACE
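The new translation unit above provides one explicit specialization of the `run_mha_fwd_` template declared in `flash.h` (bf16, head dimension 128, causal). Other (dtype, head-dim, causal) combinations presumably follow the same per-file pattern; a hypothetical sketch of an fp16, non-causal counterpart (the other instantiation files and their names are not shown in this commit view):

```cpp
// Hypothetical sketch mirroring the new file above, for fp16 / hdim128 / non-causal.
#include "contrib_ops/cuda/bert/flash_attention/namespace_config.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template <>
void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params& params, cudaStream_t stream) {
  run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
}

}  // namespace FLASH_NAMESPACE
```

Keeping one (dtype, head-dim, causal) instantiation per .cu file is what lets the quick-build filter in `onnxruntime_providers_cpu.cmake` drop bf16 and non-hdim128 kernels at the file level.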
