vulkan: Fix FA coopmat1 invalid array indexing (#16365)

When computing sinks, the cm1 shader was looping r from 0 to Br rather than
to rows_per_thread. I must have copied this from the scalar path (where it is
correct), and somehow it wasn't causing failures on current drivers.
Jeff Bolz, 3 months ago
Parent
Commit
0e1f838556
1 changed file with 2 additions and 2 deletions
  1. ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp (+2 -2)

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

@@ -358,8 +358,8 @@ void main() {
     }
 
     if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
 
             float ms = 1.0f;
             float vs = 1.0f;
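
The indexing change in this hunk can be modeled on the CPU. The following C++ is only a sketch under stated assumptions: the values of `Br` and `row_split`, the two-argument form of `tile_row`, and the helpers `rows_for` and `out_of_bounds_accesses` are hypothetical and not taken from the shader. It assumes each thread keeps per-row state in arrays of `rows_per_thread` entries, so a loop bounded by `Br` would run past the end of those arrays, and that a thread-local row `r` must be mapped to a tile row before it is passed to a function such as `perElemOpGetSink`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical tile geometry; the real values come from the shader's
// specialization constants and are not shown in this diff.
constexpr uint32_t Br = 16;                          // rows in one tile
constexpr uint32_t row_split = 4;                    // thread groups sharing the tile rows
constexpr uint32_t rows_per_thread = Br / row_split; // rows owned by one thread

// Assumed mapping from a thread-local row index r to a row of the tile.
// row_tid identifies which slice of the tile the thread owns.
constexpr uint32_t tile_row(uint32_t row_tid, uint32_t r) {
    return row_tid * rows_per_thread + r;
}

// Tile rows visited by one thread when it loops over its own rows,
// mirroring the corrected loop bound and the tile_row(r) argument.
std::vector<uint32_t> rows_for(uint32_t row_tid) {
    std::vector<uint32_t> rows;
    for (uint32_t r = 0; r < rows_per_thread; ++r) {
        rows.push_back(tile_row(row_tid, r));
    }
    return rows;
}

// Number of iterations that would index past an array of rows_per_thread
// entries if the loop ran from 0 to `limit`.
constexpr uint32_t out_of_bounds_accesses(uint32_t limit) {
    return limit > rows_per_thread ? limit - rows_per_thread : 0;
}
```

With these assumed sizes, a loop bounded by `Br` makes 12 of its 16 iterations outside a 4-entry per-thread array, while a loop bounded by `rows_per_thread` stays in range and each thread visits only its own slice of the tile.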