CUDA: fix race condition in MMQ ids_dst (#13294)

Johannes Gäßler, 8 months ago
parent commit 8afbd96818
1 changed file with 7 additions and 0 deletions
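Context for the fix: ids_dst_shared is a block-wide shared-memory buffer that mul_mat_q rewrites once per tile, but the write loops were not separated by a barrier from warps still reading the previous tile's entries, nor from warps about to read the newly written ones. Below is a minimal standalone sketch of the pattern and the fix, with illustrative names (tile_loop, ids_shared) that are not from the actual kernel:

// Each block rewrites a shared index buffer once per tile. Two barriers
// are needed per rewrite: one so no warp overwrites entries a slower warp
// is still reading from the previous tile, one so no warp reads entries
// another warp has not yet written for the current tile.
__global__ void tile_loop(const int * ids, int * out, int n_tiles, int tile_sz) {
    __shared__ int ids_shared[256]; // assumes tile_sz <= 256

    for (int tile = 0; tile < n_tiles; ++tile) {
        // Barrier 1: previous tile's readers must be done before we overwrite.
        // (Unnecessary on the very first tile; the patch leaves it commented out there.)
        __syncthreads();

        for (int j = threadIdx.x; j < tile_sz; j += blockDim.x) {
            ids_shared[j] = ids[tile*tile_sz + j];
        }

        // Barrier 2: all writes must land before any warp consumes them.
        __syncthreads();

        for (int j = threadIdx.x; j < tile_sz; j += blockDim.x) {
            out[tile*tile_sz + j] = ids_shared[j]; // consume the shared ids
        }
    }
}

This matches the placement in the diff below: each rewrite of ids_dst_shared gets a trailing __syncthreads(), each reuse across tiles gets a leading one, and the first stream-k tile only gets a commented-out leading barrier because no previous tile exists.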

ggml/src/ggml-cuda/mmq.cuh (+7, -0)

@@ -2636,6 +2636,7 @@ static __global__ void mul_mat_q(
 
         ids_dst_shared[j] = j;
     }
+    __syncthreads();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
 #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
@@ -2664,6 +2665,7 @@ static __global__ void mul_mat_q(
                 return;
             }
 
+            // __syncthreads(); // There is no previous tile that could cause a race condition.
 #pragma unroll
             for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
                 const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2674,6 +2676,7 @@ static __global__ void mul_mat_q(
 
                 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
             }
+            __syncthreads();
         }
 
         offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2740,6 +2743,7 @@ static __global__ void mul_mat_q(
                 continue;
             }
 
+            __syncthreads();
 #pragma unroll
             for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
                 const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2750,6 +2754,7 @@ static __global__ void mul_mat_q(
 
                 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
             }
+            __syncthreads();
         }
 
         offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2805,6 +2810,7 @@ static __global__ void mul_mat_q(
         }
 
         // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
+        __syncthreads();
 #pragma unroll
         for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
             const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2815,6 +2821,7 @@ static __global__ void mul_mat_q(
 
             ids_dst_shared[j] = j;
         }
+        __syncthreads();
     }
 
     offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));