Procházet zdrojové kódy

CUDA: fix pointer incrementation in FA (#14916)

Johannes Gäßler před 5 měsíci
rodič
revize
946b1f6859

+ 4 - 5
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
             }
         }
 
-        K     += gridDim.y*D * nb11;
-        V     += gridDim.y*D * nb21;
-        maskh += gridDim.y*D;
-
         __syncthreads();
     }
 

+ 4 - 5
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
             }
         }
 
-        K     += gridDim.y*D * nb11;
-        V     += gridDim.y*D * nb21;
-        maskh += gridDim.y*D;
-
         __syncthreads();
     }