Explorar el Código

CUDA: fix race conditions FlashAttention kernels (#13438)

Johannes Gäßler hace 8 meses
padre
commit
0208355f42
Se han modificado 2 ficheros con 3 adiciones y 0 borrados
  1. 2 0
      ggml/src/ggml-cuda/fattn-mma-f16.cuh
  2. 1 0
      ggml/src/ggml-cuda/fattn-vec-f16.cuh

+ 2 - 0
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
            }
        }
 
 
+        __syncthreads();
+
        // Write back combined meta data:
#pragma unroll
        for (int imeta = 0; imeta < nmeta; ++imeta) {

+ 1 - 0
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
    for (int j = 0; j < ncols; ++j) {
        KQ[j*D + tid] = -HALF_MAX_HALF;
    }
+    __syncthreads();
 
 
    half2 VKQ[ncols] = {{0.0f, 0.0f}};