Просмотр исходного кода

metal : fix kernel requirements (#15983)

* metal : fix kernel requirements

ggml-ci

* cont : fix supports_op

* cont : fix supports_op for ARGMAX
Georgi Gerganov 4 месяцев назад
Родитель
Сommit
a14bd35014
1 измененных файлов с 8 добавлено и 7 удалено
  1. 8 7
      ggml/src/ggml-metal/ggml-metal.m

+ 8 - 7
ggml/src/ggml-metal/ggml-metal.m

@@ -1219,10 +1219,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                    ssm_conv_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                    ssm_conv_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,              ssm_scan_f32_group,              true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,              ssm_scan_f32_group,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                   rwkv_wkv6_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                   rwkv_wkv6_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                   rwkv_wkv7_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                   rwkv_wkv7_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                  mul_mv_f32_f32,                  has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                  mul_mv_f32_f32,                  has_simdgroup_reduction);
@@ -1443,9 +1443,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,                      swiglu_oai,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,                      swiglu_oai,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,                       geglu_erf,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,                       geglu_erf,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,                     geglu_quick,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,                     geglu_quick,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
     }
     }
@@ -1982,7 +1982,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_L2_NORM:
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
         case GGML_OP_ARGMAX:
-            return true;
+            return has_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
         case GGML_OP_ROPE:
@@ -2028,6 +2028,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
         case GGML_OP_SSM_SCAN:
+            return has_simdgroup_reduction;
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
         case GGML_OP_RWKV_WKV7:
             return true;
             return true;