|
|
@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
|
|
|
(char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
|
|
|
output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
|
|
|
output_ne_offset);
|
|
|
+ int64_t antiquantGroupSize = 0;
|
|
|
+ if (src0->ne[0] > QK8_0) {
|
|
|
+ antiquantGroupSize = QK8_0;
|
|
|
+ }
|
|
|
|
|
|
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
|
|
|
acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
|
|
|
- nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
|
|
|
+ nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor,
|
|
|
&workspaceSize, &executor));
|
|
|
if (workspaceAddr == nullptr) {
|
|
|
workspaceAddr = workspace_allocator.alloc(workspaceSize);
|
|
|
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
|
|
|
|
|
|
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
|
|
|
acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
|
|
|
- nullptr, nullptr, nullptr, nullptr, QK8_0,
|
|
|
+ nullptr, nullptr, nullptr, nullptr, antiquantGroupSize,
|
|
|
acl_output_tensor, &workspaceSize, &executor));
|
|
|
ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
|
|
|
workspaceAddr, workspaceSize, executor, ctx.stream()));
|