ソースを参照

vulkan: Fix data race/hang in scalar/cm1 flash attention (#17887)

Jeff Bolz 1 ヶ月 前
コミット
3238b1400c

+ 3 - 0
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

@@ -256,6 +256,9 @@ void main() {
         barrier();
     }
 
+    // prevent race on tmpsh
+    barrier();
+
     // reduce across threads
 
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {

+ 3 - 0
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

@@ -302,6 +302,9 @@ void main() {
         barrier();
     }
 
+    // prevent race on tmpsh
+    barrier();
+
     // reduce across threads
 
     float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];