
cuda : update supports_op for matrix multiplication (#8245)

slaren 1 year ago
Parent
Commit
0e0590adab
2 changed files with 31 additions and 17 deletions
  1. ggml/src/ggml-cuda.cu       +30 -17
  2. tests/test-backend-ops.cpp  +1 -0

+ 30 - 17
ggml/src/ggml-cuda.cu

@@ -2711,27 +2711,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
+                struct ggml_tensor * a = op->src[0];
                 if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != b->ne[3]) {
-                    return false;
-                }
-                ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
-                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                    a_type == GGML_TYPE_IQ1_M   || a_type == GGML_TYPE_IQ2_S  || a_type == GGML_TYPE_IQ4_XS) {
-                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                    struct ggml_tensor * b = op->src[1];
+                    if (a->ne[3] != b->ne[3]) {
                         return false;
                     }
                 }
-                return true;
+                switch (a->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q2_K:
+                    case GGML_TYPE_Q3_K:
+                    case GGML_TYPE_Q4_K:
+                    case GGML_TYPE_Q5_K:
+                    case GGML_TYPE_Q6_K:
+                    case GGML_TYPE_Q8_K:
+                    case GGML_TYPE_IQ1_M:
+                    case GGML_TYPE_IQ1_S:
+                    case GGML_TYPE_IQ2_S:
+                    case GGML_TYPE_IQ2_XS:
+                    case GGML_TYPE_IQ2_XXS:
+                    case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ3_XXS:
+                    case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_IQ4_XS:
+                        return true;
+                    default:
+                        return false;
+                }
             } break;
         case GGML_OP_GET_ROWS:
             {
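The rewritten check replaces the old IQ-type blacklist with an explicit whitelist of supported source types, and applies the ne[3] broadcast check only to GGML_OP_MUL_MAT. A minimal sketch of how a caller could exercise this through the public backend API (the standalone program and tensor shapes are illustrative, not part of the commit):

    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-cuda.h"
    #include <stdio.h>

    int main(void) {
        // no_alloc: only tensor metadata is needed for a supports_op query
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx = ggml_init(params);

        // hypothetical shapes: a 256x16 Q4_0 weight times a 256x1 F32 column
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 256, 16);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  256, 1);
        struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);  // c->src[0] = a, c->src[1] = b

        ggml_backend_t backend = ggml_backend_cuda_init(0);
        if (backend != NULL) {
            // dispatches to ggml_backend_cuda_supports_op, i.e. the switch added above
            printf("CUDA mul_mat supported: %s\n",
                   ggml_backend_supports_op(backend, c) ? "yes" : "no");
            ggml_backend_free(backend);
        }

        ggml_free(ctx);
        return 0;
    }

With the patch applied, the same query answers false for any source type not listed in the switch, instead of relying on the previous blacklist of IQ types.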

+ 1 - 0
tests/test-backend-ops.cpp

@@ -2052,6 +2052,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
         GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+        GGML_TYPE_BF16,
     };
 
     // unary ops
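The test change adds GGML_TYPE_BF16 to the type list swept by test-backend-ops, presumably so the MUL_MAT coverage includes BF16 as well. Since the switch above does not list GGML_TYPE_BF16, the CUDA backend reports such cases as unsupported. Continuing the sketch above (names and shapes are illustrative):

    // the same query for a BF16 weight is expected to return false with this patch,
    // because GGML_TYPE_BF16 is not among the whitelisted cases
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_BF16, 256, 16);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  256, 1);
    struct ggml_tensor * y = ggml_mul_mat(ctx, w, x);
    bool bf16_ok = ggml_backend_supports_op(backend, y);  // expected: false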