musa: fix compilation warnings in mp_22/31 (#12780)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
R0CKSTAR, 9 months ago
Parent
Commit 916c83bfe7
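
The fixes fall into two patterns. First, parameters that go unused when a code path is compiled out (the #else branch added in cpy.cu, and the kernel variants that reduce to NO_DEVICE_CODE in the fattn files) are now marked with GGML_UNUSED. Second, loop bounds of the form D/sizeof(int) have unsigned type because sizeof yields size_t, so comparing a signed loop counter against them warns; casting the bound to int keeps the comparison signed.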

+ 3 - 0
ggml/src/ggml-cuda/cpy.cu

@@ -360,6 +360,9 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
     // copy destination pointers to GPU
     CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
     cuda_graph->graph_cpynode_index = 0; // reset index
+#else
+    GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs);
+    GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream);
 #endif
 }
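
For reference, here is the GGML_UNUSED pattern as a self-contained sketch. GGML_UNUSED is reproduced from ggml.h (a cast to void); the USE_GRAPHS guard and copy_dest_ptrs function are hypothetical stand-ins for GGML_CUDA_USE_GRAPHS and ggml_cuda_cpy_dest_ptrs_copy, not the real code:

    #include <cstdio>

    // GGML_UNUSED as defined in ggml.h: a cast to void.
    #define GGML_UNUSED(x) (void)(x)

    // Hypothetical stand-in for the GGML_CUDA_USE_GRAPHS guard; leave it
    // undefined to exercise the #else branch.
    //#define USE_GRAPHS

    static void copy_dest_ptrs(char ** host_dest_ptrs, int host_dest_ptrs_size) {
    #ifdef USE_GRAPHS
        printf("copying %d pointers from %p\n", host_dest_ptrs_size, (void *) host_dest_ptrs);
    #else
        // The real body is compiled out here, so both parameters would
        // trigger -Wunused-parameter; casting each to void marks the
        // non-use as intentional.
        GGML_UNUSED(host_dest_ptrs);
        GGML_UNUSED(host_dest_ptrs_size);
    #endif
    }

    int main() {
        copy_dest_ptrs(nullptr, 0);
        return 0;
    }

Compiled with the guard off (e.g. g++ -Wall -Wextra), this builds cleanly; removing the two GGML_UNUSED lines brings the warning back.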

+ 5 - 5
ggml/src/ggml-cuda/fattn-common.cuh

@@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ /  QI8_1;
@@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ /  QI8_1;
@@ -146,7 +146,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ /  QI8_1;
@@ -193,7 +193,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ /  QI8_1;
@@ -244,7 +244,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib  = k_KQ / QI8_0;
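
All five hunks above fix the same signed/unsigned comparison warning: sizeof(int) has type size_t, so D/sizeof(int) is unsigned, and comparing the signed counter k_KQ_0 against it triggers a sign-compare diagnostic under MUSA. Casting the bound with int(...) keeps the comparison signed. A minimal sketch, with D as a plain constant standing in for the head-size template parameter:

    #include <cstdio>

    int main() {
        constexpr int D = 128; // stand-in for the head-size template parameter

        // Without the int(...) cast, `k_KQ_0 < D/sizeof(int)` compares a
        // signed int against a size_t and the compiler warns; the cast
        // keeps both sides of the comparison signed.
        int iters = 0;
        for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); ++k_KQ_0) {
            ++iters;
        }
        printf("%d iterations\n", iters); // 32 when D == 128 and sizeof(int) == 4
        return 0;
    }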

+ 12 - 0
ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -52,6 +52,18 @@ static __global__ void flash_attn_tile_ext_f32(
     return;
 #endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+        GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+        GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+        GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
         NO_DEVICE_CODE;
         return;
     }
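
As in cpy.cu, the warning comes from a body that never reads its arguments: for this skipped variant the kernel reduces to NO_DEVICE_CODE plus an early return, so every kernel argument needs a GGML_UNUSED to keep the MUSA build warning-free. The same block is added to flash_attn_vec_ext_f32 below.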

+ 15 - 3
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -45,6 +45,18 @@ static __global__ void flash_attn_vec_ext_f32(

     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+        GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+        GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+        GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
         NO_DEVICE_CODE;
         return;
     }
@@ -114,7 +126,7 @@ static __global__ void flash_attn_vec_ext_f32(
             // Set memory to zero if out of bounds:
             if (ncols > 2 && ic0 + j >= ne01) {
 #pragma unroll
-                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                     const int i = i0 + threadIdx.x;

                     tmp_q_i32[i] = 0;
@@ -127,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f32(

             const float * Q_f = (const float *) (Q + j*nb01);
 #pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                 quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
             }
         }
@@ -140,7 +152,7 @@ static __global__ void flash_attn_vec_ext_f32(
             float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));

 #pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;

                 Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];