Переглянути джерело

CUDA: fix FlashAttention on Turing (#13415)

Johannes Gäßler 8 місяців тому
батько
коміт
d8919424f1
1 змінених файлів з 1 додано та 1 видалено
  1. 1 1
      ggml/src/ggml-cuda/fattn-mma-f16.cuh

+ 1 - 1
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -546,7 +546,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
         const int i0_diff = i0_stop - i0_start;
 
-        if (nstages == 1) {
+        if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
                 (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);