Преглед изворни кода

CUDA: fix Volta FlashAttention logic (#11615)

Johannes Gäßler пре 11 месеци
родитељ
комит
21c84b5d2d
2 измењених фајлова са 3 додато и 2 уклоњено
  1. 1 1
      ggml/src/ggml-cuda/fattn-wmma-f16.cu
  2. 2 1
      ggml/src/ggml-cuda/fattn.cu

+ 1 - 1
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                     break;
                 // case 256:
-                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                 //     break;
                 default:
                     GGML_ABORT("fatal error");

+ 2 - 1
ggml/src/ggml-cuda/fattn.cu

@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }
 
     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);