@@ -299,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
-    const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
+    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
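
A minimal standalone sketch of the relaxed eligibility check above, for illustration only (not part of the patch); it assumes warp_size == 32 as on NVIDIA GPUs, and the helper name and sample head sizes are hypothetical:

// Illustrative sketch only: mirrors the new batch-size-1 vector-kernel
// eligibility condition from the diff. warp_size is assumed to be 32
// (NVIDIA); the helper name and sample head sizes are made up.
#include <cstdint>
#include <cstdio>

static bool can_use_vector_kernel(int64_t head_size, int64_t warp_size = 32) {
    // New condition: the head size (Q->ne[0]) only has to be a multiple of 2*warp_size.
    // The old extra requirement (default/F16 precision or head size <= 128) is dropped,
    // so e.g. a head size of 256 with F32 precision now also qualifies.
    return head_size % (2*warp_size) == 0;
}

int main() {
    for (int64_t hs : {64, 80, 128, 256}) {
        printf("head size %3lld -> %s\n", (long long) hs,
               can_use_vector_kernel(hs) ? "vector kernel eligible" : "not eligible");
    }
    return 0;
}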