Bläddra i källkod

vulkan: fix MMQ quantize_y condition (#17301)

Ruben Ortlam 2 månader sedan
förälder
incheckning
80deff3648
1 ändrade filer med 3 tillägg och 3 borttagningar
  1. 3 3
      ggml/src/ggml-vulkan/ggml-vulkan.cpp

+ 3 - 3
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -6444,7 +6444,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
 
     // Check for mmq first
     vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
@@ -6731,7 +6731,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
@@ -7220,7 +7220,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
 
     // Check for mmq first
     vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;