
CUDA: fix FA out-of-bounds writes (#7465)

Johannes Gäßler 1 year ago
Parent
Commit
38c03478a3
4 changed files with 18 additions and 2 deletions
  1. +4 -0 ggml-cuda/fattn-tile-f16.cu
  2. +4 -0 ggml-cuda/fattn-tile-f32.cu
  3. +5 -1 ggml-cuda/fattn-vec-f16.cu
  4. +5 -1 ggml-cuda/fattn-vec-f32.cu

+ 4 - 0
ggml-cuda/fattn-tile-f16.cu

@@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum(kqsum_j);
 

+ 4 - 0
ggml-cuda/fattn-tile-f32.cu

@@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 

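Both tile hunks add the same early-out. The grid is laid out in tiles of ncols columns, so when ne01 is not a multiple of ncols the last block covers columns past the end of the output, and without the bound check those columns were written out of bounds. A minimal standalone sketch of the pattern, not from the commit: ncols, nwarps, ne01 and ic0 mirror the kernels above, but the write-back body is a hypothetical stand-in for the real VKQ store.

// Hypothetical stand-in for the tile kernels' write-back loop: the grid is
// sized in tiles of ncols columns, so the last block may start at an ic0
// with fewer than ncols valid columns left.
#include <cstdio>
#include <cuda_runtime.h>

template <int ncols, int nwarps>
static __global__ void write_back(float * dst, const int ne01, const int D) {
    const int ic0 = blockIdx.x*ncols; // first column handled by this block
    const int tid = threadIdx.x;      // position within the column

    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        // The guard from the commit: without it the last block writes past
        // the end of dst whenever ne01 % ncols != 0.
        if (ic0 + j >= ne01) {
            return;
        }

        dst[(ic0 + j)*D + tid] = 1.0f; // placeholder for the real VKQ value
    }
}

int main() {
    const int ne01 = 10, D = 32;      // 10 columns, processed in tiles of 8
    constexpr int ncols = 8, nwarps = 4;

    float * dst;
    cudaMalloc(&dst, ne01*D*sizeof(float)); // exactly ne01 columns, no padding

    const dim3 blocks((ne01 + ncols - 1)/ncols); // 2 blocks, the 2nd has only 2 valid columns
    const dim3 threads(D, nwarps);
    write_back<ncols, nwarps><<<blocks, threads>>>(dst, ne01, D);

    printf("%s\n", cudaGetErrorString(cudaDeviceSynchronize()));
    cudaFree(dst);
    return 0;
}

Returning is safe here because nothing after the loop needs the exited threads, and since the tile kernels launch with blockDim.x equal to the warp size, every thread of a warp shares one threadIdx.y: a warp either passes the guard together or returns together, so the warp_reduce_sum that follows still sees a full warp.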
+ 5 - 1
ggml-cuda/fattn-vec-f16.cu

@@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
@@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols) {
+    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else

+ 5 - 1
ggml-cuda/fattn-vec-f32.cu

@@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
@@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols) {
+    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }
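The vec hunks need the same bound check but use break rather than return: the dst_meta write after the loop must still execute, and it gets its own ic0 + tid < ne01 term so the last block cannot write past dst_meta either. A minimal sketch of both guards, not from the commit: gridDim.y is taken as 1 to shorten the dst_meta index, and the stored float2 stands in for the real (kqmax, kqsum) pair.

// Hypothetical stand-in for the vec kernels' tail: a per-column loop that
// breaks (not returns) on out-of-bounds columns, followed by the guarded
// per-column dst_meta write used when parallel_blocks > 1.
#include <cstdio>
#include <cuda_runtime.h>

template <int ncols, int parallel_blocks>
static __global__ void write_meta(float * dst, float2 * dst_meta, const int ne01, const int D) {
    const int ic0 = blockIdx.x*ncols; // first column handled by this block
    const int tid = threadIdx.x;
    const int ip  = blockIdx.z;       // index among the parallel blocks

    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
        // break, not return: the dst_meta write below must still run
        if (ic0 + j_VKQ >= ne01) {
            break;
        }
        dst[(ic0 + j_VKQ)*D + tid] = 1.0f; // placeholder for the real VKQ value
    }

    // One thread per column stores its (kqmax, kqsum) pair; the added
    // ic0 + tid < ne01 term keeps the last block inside dst_meta.
    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
        dst_meta[(ic0 + tid)*parallel_blocks + ip] = make_float2(0.0f, 1.0f);
    }
}

int main() {
    const int ne01 = 3, D = 32;       // 3 columns, processed in tiles of 4
    constexpr int ncols = 4, parallel_blocks = 2;

    float  * dst;
    float2 * dst_meta;
    cudaMalloc(&dst,      ne01*D*sizeof(float));
    cudaMalloc(&dst_meta, ne01*parallel_blocks*sizeof(float2)); // exactly ne01 entries per parallel block

    const dim3 blocks((ne01 + ncols - 1)/ncols, 1, parallel_blocks);
    write_meta<ncols, parallel_blocks><<<blocks, D>>>(dst, dst_meta, ne01, D);

    printf("%s\n", cudaGetErrorString(cudaDeviceSynchronize()));
    cudaFree(dst);
    cudaFree(dst_meta);
    return 0;
}

With ncols = 4 and ne01 = 3, the block breaks out after three columns and writes only three float2 pairs per parallel block; under the old condition thread tid = 3 would have written a fourth pair past the end of dst_meta.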