|
|
@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- if (!new_mma_available(cc)) {
|
|
|
+ if (!fp16_mma_available(cc)) {
|
|
|
if (prec == GGML_PREC_DEFAULT) {
|
|
|
if (Q->ne[1] <= 8) {
|
|
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
|
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
|
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
|
|
if (cc == GGML_CUDA_CC_VOLTA) {
|
|
|
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
|
|
+ return;
|
|
|
}
|
|
|
|
|
|
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|