Browse Source

metal : Add template specialization for mul_mm_id w/ ne20 == 10 (#15799)

Branch: GGMLMetalNE20

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart 4 months ago
parent
commit
856ed0947f
2 changed files with 4 additions and 0 deletions
  1. 3 0
      ggml/src/ggml-metal/ggml-metal.m
  2. 1 0
      ggml/src/ggml-metal/ggml-metal.metal

+ 3 - 0
ggml/src/ggml-metal/ggml-metal.m

@@ -407,6 +407,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
@@ -1439,6 +1440,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,       mul_mm_id_map0_f16_ne20_4,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,       mul_mm_id_map0_f16_ne20_4,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,       mul_mm_id_map0_f16_ne20_6,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,       mul_mm_id_map0_f16_ne20_6,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,       mul_mm_id_map0_f16_ne20_8,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,       mul_mm_id_map0_f16_ne20_8,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,      mul_mm_id_map0_f16_ne20_10,      has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,      mul_mm_id_map0_f16_ne20_16,      has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,      mul_mm_id_map0_f16_ne20_16,      has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,               mul_mm_id_f32_f16,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,               mul_mm_id_f32_f16,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,               mul_mm_id_f16_f16,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,               mul_mm_id_f16_f16,               has_simdgroup_mm);
@@ -3979,6 +3981,7 @@ static int ggml_metal_encode_node(
                             case 4:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
                             case 4:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
                             case 6:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
                             case 6:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
                             case 8:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
                             case 8:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
+                            case 10: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10].pipeline; break;
                             case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
                             case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
                             default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
                             default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
                         }
                         }

+ 1 - 0
ggml/src/ggml-metal/ggml-metal.metal

@@ -7618,6 +7618,7 @@ template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
 
 
 template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
 template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>