Explorar o código

CUDA: fix overflow in MMA kernel without stream-k (#17939)

Johannes Gäßler hai 1 mes
pai
achega
482211438d

+ 3 - 3
ggml/src/ggml-cuda/fattn-common.cuh

@@ -642,8 +642,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
     const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -679,7 +679,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+        const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;

+ 3 - 3
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -1380,8 +1380,8 @@ static __global__ void flash_attn_ext_f16(
     const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1401,7 +1401,7 @@ static __global__ void flash_attn_ext_f16(
         const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
         const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
-            (const half  *) (mask + nb33*(sequence % ne33));
+            (const half *) (mask + nb33*(sequence % ne33));
         float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));