Przeglądaj źródła

opencl: refine condition for kqv mm (#17392)

lhez 2 miesięcy temu
rodzic
commit
8e9ddba610
1 zmienionych plików z 17 dodań i 3 usunięć
  1. 17 3
      ggml/src/ggml-opencl/ggml-opencl.cpp

+ 17 - 3
ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -6895,9 +6895,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
     cl_context context = backend_ctx->context;
 
     if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
-        if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
-            ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
-            return;
+        if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
+            // For KQ
+            if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
+                nb00 <= nb02 &&
+                nb02 <= nb01 &&
+                nb01 <= nb03 &&
+                nb10 <= nb12 &&
+                nb12 <= nb11 &&
+                nb11 <= nb13) {
+                ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
+                return;
+            }
+            // For KQV
+            if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+                ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
+                return;
+            }
         }
     }