|
|
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
K += blockIdx.y*D * nb11;
|
|
|
V += blockIdx.y*D * nb21;
|
|
|
maskh += blockIdx.y*D;
|
|
|
- for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
|
|
+ for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
|
|
|
+ // Increment pointers after each loop:
|
|
|
+ K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
|
|
+
|
|
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
|
|
|
|
|
if (mask) {
|
|
|
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- K += gridDim.y*D * nb11;
|
|
|
- V += gridDim.y*D * nb21;
|
|
|
- maskh += gridDim.y*D;
|
|
|
-
|
|
|
__syncthreads();
|
|
|
}
|
|
|
|