Răsfoiți Sursa

CUDA: fix FA tg at long context for CC >= 8.9 (#13852)

Johannes Gäßler 7 luni în urmă
părinte
comite
a68247439b
1 a modificat fișierele cu 2 adăugiri și 2 ștergeri
  1. 2 2
      ggml/src/ggml-cuda/fattn-common.cuh

+ 2 - 2
ggml/src/ggml-cuda/fattn-common.cuh

@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
-    if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
+    for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
     }
 
     __syncthreads();