Просмотр исходного кода

CUDA: fix race condition in FA vector kernels (#13742)

Johannes Gäßler 7 месяцев назад
Родитель
Сommit
ffd0eae60b
2 измененных файлов с 2 добавлено и 0 удалено
  1. 1 0
      ggml/src/ggml-cuda/fattn-vec-f16.cuh
  2. 1 0
      ggml/src/ggml-cuda/fattn-vec-f32.cuh

+ 1 - 0
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
                 }
             }
             if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
                 continue;
             }
 #endif // GGML_USE_HIP

+ 1 - 0
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 }
             }
             if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
                 continue;
             }
 #endif // GGML_USE_HIP