
CUDA/HIP: optimize mmv paths taken for HIP devices (#14324)

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
uvos, 6 months ago
commit 0142961a2e
2 changed files with 23 additions and 1 deletion
  1. ggml/src/ggml-cuda/common.cuh (+5 −1)
  2. ggml/src/ggml-cuda/mmv.cu (+18 −0)

ggml/src/ggml-cuda/common.cuh (+5 −1)

@@ -263,7 +263,11 @@ static bool fp16_mma_hardware_available(const int cc) {
 }
 
 static bool bf16_mma_hardware_available(const int cc) {
-    return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+}
+
+static bool fp32_mma_hardware_available(const int cc) {
+    return GGML_CUDA_CC_IS_CDNA(cc);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
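The common.cuh hunk widens bf16_mma_hardware_available() to also cover CDNA and RDNA3 or newer, and adds an fp32_mma_hardware_available() predicate that is true only on CDNA. These helpers are pure capability checks; the self-contained sketch below shows the usual consumer pattern, with hypothetical launch_* stubs and a plain boolean standing in for ggml's packed compute-capability value (so it is an illustration, not the actual ggml-cuda dispatch code):

    // Illustrative only: how a per-precision MMA-availability predicate gates
    // kernel selection. launch_* functions and the is_cdna flag are hypothetical
    // stand-ins; the real predicate takes ggml's packed cc value.
    #include <cstdio>

    static bool fp32_mma_hardware_available_sketch(bool is_cdna) {
        return is_cdna;                      // mirrors the new predicate: CDNA only
    }

    static void launch_fp32_mma_kernel()  { std::puts("FP32 path: MMA tile kernel"); }
    static void launch_fp32_simd_kernel() { std::puts("FP32 path: plain FMA kernel"); }

    static void run_fp32_matmul(bool is_cdna) {
        if (fp32_mma_hardware_available_sketch(is_cdna)) {
            launch_fp32_mma_kernel();        // CDNA: use the matrix cores
        } else {
            launch_fp32_simd_kernel();       // everything else: fall back
        }
    }

    int main() {
        run_fp32_matmul(true);   // CDNA device
        run_fp32_matmul(false);  // e.g. an RDNA2 device
        return 0;
    }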

ggml/src/ggml-cuda/mmv.cu (+18 −0)

@@ -456,6 +456,11 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
                     return ne11 <= 4;
                 }
                 return ne11 <= 3;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp32_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
             }
             return ne11 <= 8;
         case GGML_TYPE_F16:
@@ -468,6 +473,14 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
                     return src0_small && ne11 <= 3;
                 }
                 return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp16_mma_hardware_available(cc)) {
+                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+                        return ne11 <= 5;
+                    }
+                    return ne11 <= 2;
+                }
+                return ne11 <= 8;
             }
             return ne11 <= 8;
         case GGML_TYPE_BF16:
@@ -480,6 +493,11 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
                     return src0_small && ne11 <= 3;
                 }
                 return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (bf16_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
             }
             return ne11 <= 8;
         default:
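Taken together, the mmv.cu hunks give AMD devices their own ne11 cut-offs for routing small batches to the mmv kernel instead of always falling through to the generic ne11 <= 8 rule. The standalone sketch below condenses the new AMD-side limits; the architecture enum is a mock, and it assumes (not visible in this diff) that fp16_mma_hardware_available() reports true on CDNA and RDNA3 or newer:

    #include <cstdint>
    #include <cstdio>

    // Mock stand-ins; the real code keys off packed GGML_CUDA_CC_* values and the
    // *_mma_hardware_available() predicates from common.cuh.
    enum class AmdArch { Cdna, Rdna1, Rdna2, Rdna3, Rdna4, Other };
    enum class SrcType { F32, F16, BF16 };

    // AMD-side ne11 limits below which the mmv kernel is preferred, per this commit.
    // Assumption (not shown in the diff): FP16 MMA is reported on CDNA and RDNA3+.
    static int64_t amd_mmv_ne11_limit_sketch(AmdArch arch, SrcType type) {
        const bool rdna3plus = arch == AmdArch::Rdna3 || arch == AmdArch::Rdna4;
        switch (type) {
            case SrcType::F32:
                return arch == AmdArch::Cdna ? 3 : 8;       // FP32 matrix cores: CDNA only
            case SrcType::F16:
                if (rdna3plus)             return 5;
                if (arch == AmdArch::Cdna) return 2;
                return 8;                                    // no FP16 MMA: keep mmv up to 8
            case SrcType::BF16:
                return (arch == AmdArch::Cdna || rdna3plus) ? 3 : 8;
        }
        return 8;
    }

    int main() {
        std::printf("CDNA  F16  -> mmv for ne11 <= %d\n", (int) amd_mmv_ne11_limit_sketch(AmdArch::Cdna,  SrcType::F16));
        std::printf("RDNA3 F16  -> mmv for ne11 <= %d\n", (int) amd_mmv_ne11_limit_sketch(AmdArch::Rdna3, SrcType::F16));
        std::printf("RDNA2 BF16 -> mmv for ne11 <= %d\n", (int) amd_mmv_ne11_limit_sketch(AmdArch::Rdna2, SrcType::BF16));
        return 0;
    }

The NVIDIA branches above it are unchanged; the new else-if blocks only refine the choice for AMD hardware, lowering the cut-off where matrix cores make the general GEMM path competitive sooner.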