Просмотр исходного кода

HIP: fix flash_attn_stream_k_fixup warning (#11604)

Johannes Gäßler 11 месяцев назад
Родитель
Сommit
6eecde3cc8
2 измененных файлов с 12 добавлено и 2 удалено
  1. 10 0
      ggml/src/ggml-cuda/fattn-common.cuh
  2. 2 2
      ggml/src/ggml-cuda/softmax.cu

+ 10 - 0
ggml/src/ggml-cuda/fattn-common.cuh

@@ -516,6 +516,12 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
 template<int D, int ncols, int KQ_stride> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -614,6 +620,10 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 }
 
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)

+ 2 - 2
ggml/src/ggml-cuda/softmax.cu

@@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #ifdef __clang__
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wpass-failed"
-#endif
+#endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
         const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@@ -126,7 +126,7 @@ static __global__ void soft_max_f32(
 }
 #ifdef __clang__
 #pragma clang diagnostic pop
-#endif
+#endif // __clang__
 
 static __global__ void soft_max_back_f32(
         const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {