Просмотр исходного кода

CUDA: broadcasting for FlashAttention mask (#14500)

Johannes Gäßler 6 месяцев назад
Родитель
Сommit
12a81af45f

+ 4 - 1
ggml/src/ggml-cuda/fattn-common.cuh

@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,

+ 8 - 4
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
 
         const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
         const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
         float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
 
     const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
     const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-    const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+    const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
     float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
     const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
     GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
     GGML_UNUSED(ne2); GGML_UNUSED(ne3);

+ 6 - 4
ggml/src/ggml-cuda/fattn-tile-f16.cu

@@ -6,7 +6,7 @@
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
     const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
     const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
     const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
+    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

+ 6 - 4
ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -6,7 +6,7 @@
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
         GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
     const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
     const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
     const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic0;
+    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 

+ 5 - 3
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio);
 
-    const half * maskh = (const half   *)  mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

+ 6 - 3
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
         GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
     Q += nb02* blockIdx.z              + nb01*ic0;
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
-    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
 
     const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 

+ 8 - 6
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+    const float * Q_f   = (const float *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half  * V_h   = (const half  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const half2 * mask2 = (const half2 *)  maskh;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);