Procházet zdrojové kódy

CUDA: fix FA tg at long context for CC >= 8.9 (#13852)

Johannes Gäßler před 8 měsíci
rodič
revize
a68247439b
1 změnil soubory, kde provedl 2 přidání a 2 odebrání
  1. 2 2
      ggml/src/ggml-cuda/fattn-common.cuh

+ 2 - 2
ggml/src/ggml-cuda/fattn-common.cuh

@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
-    if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
+    for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
     }
 
     __syncthreads();