소스 검색

CUDA: fix FA tg at long context for CC >= 8.9 (#13852)

Johannes Gäßler 8 달 전
부모
커밋
a68247439b
1개의 변경된 파일2개의 추가작업 그리고 2개의 파일을 삭제
  1. 2 2
      ggml/src/ggml-cuda/fattn-common.cuh

+ 2 - 2
ggml/src/ggml-cuda/fattn-common.cuh

@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
-    if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
+    for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
     }
 
     __syncthreads();