فهرست منبع

CUDA/HIP: Fix fattn-vec-* when device warp size is not 32 (#12315)

When fattn-wmma was ported over to warp64 various bits that also touch fattn-vec where converted to
selectable warp size, however the fattn-vec kernels dont work with 64 wide warps for now, so we need
to avoid launching them with parameters for warp64
uvos 10 ماه پیش
والد
کامیت
34c961b181
2فایلهای تغییر یافته به همراه26 افزوده شده و 33 حذف شده
  1. 22 30
      ggml/src/ggml-cuda/fattn-common.cuh
  2. 4 3
      ggml/src/ggml-cuda/fattn-wmma-f16.cu

+ 22 - 30
ggml/src/ggml-cuda/fattn-common.cuh

@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
         nullptr;
 }
 
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+    const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -704,8 +699,6 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,

+ 4 - 3
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -469,6 +469,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -485,7 +486,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,7 +501,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -514,7 +515,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {