Parcourir la source

[CANN]MUL_MAT optimization (#12382)

Chenguang Li il y a 10 mois
Parent
commit
92a391327e
2 fichiers modifiés avec 6 ajouts et 7 suppressions
  1. 6 2
      ggml/src/ggml-cann/aclnn_ops.cpp
  2. 0 5
      ggml/src/ggml-cann/ggml-cann.cpp

+ 6 - 2
ggml/src/ggml-cann/aclnn_ops.cpp

@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
                 (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
                 output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
                 output_ne_offset);
+            int64_t antiquantGroupSize = 0;
+            if (src0->ne[0] > QK8_0) {
+                antiquantGroupSize = QK8_0;
+            }
 
             ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
                 acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
-                nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
+                nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor,
                 &workspaceSize, &executor));
             if (workspaceAddr == nullptr) {
                 workspaceAddr = workspace_allocator.alloc(workspaceSize);
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
 
                 ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
                     acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
-                    nullptr, nullptr, nullptr, nullptr, QK8_0,
+                    nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,
                     acl_output_tensor, &workspaceSize, &executor));
                 ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
                     workspaceAddr, workspaceSize, executor, ctx.stream()));

+ 0 - 5
ggml/src/ggml-cann/ggml-cann.cpp

@@ -1689,11 +1689,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
                 case GGML_TYPE_Q8_0:
-                    // Current groupsize should not be greater than k-1 in
-                    // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
-                    if (op->src[0]->ne[0] <= QK8_0) {
-                        return false;
-                    }
                 case GGML_TYPE_F16:
                 case GGML_TYPE_F32:
                 case GGML_TYPE_Q4_0: