|
|
@@ -232,28 +232,6 @@ struct ggml_metal_kernel {
|
|
|
@end
|
|
|
|
|
|
enum ggml_metal_kernel_type {
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
|
|
|
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
|
|
|
- GGML_METAL_KERNEL_TYPE_SUB,
|
|
|
- GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
|
|
|
- GGML_METAL_KERNEL_TYPE_MUL,
|
|
|
- GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
|
|
|
- GGML_METAL_KERNEL_TYPE_DIV,
|
|
|
- GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
|
|
|
GGML_METAL_KERNEL_TYPE_ADD_ID,
|
|
|
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
|
|
@@ -319,9 +297,6 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
|
|
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
|
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
|
- GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
|
- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
|
|
|
- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
|
|
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
|
GGML_METAL_KERNEL_TYPE_NORM,
|
|
|
@@ -1177,28 +1152,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
|
|
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
|
|
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
|
|
@@ -1264,9 +1217,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
|
|
|
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
|
@@ -1722,6 +1672,73 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_re
|
|
|
GGML_UNUSED(op);
|
|
|
}
|
|
|
|
|
|
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_bin(
|
|
|
+ ggml_backend_t backend, enum ggml_op op,
|
|
|
+ int32_t n_fuse,
|
|
|
+ bool row) {
|
|
|
+ struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
+
|
|
|
+ char base[256];
|
|
|
+ char name[256];
|
|
|
+
|
|
|
+ @autoreleasepool {
|
|
|
+ const char * op_str = "undefined";
|
|
|
+ switch (op) {
|
|
|
+ case GGML_OP_ADD: op_str = "add"; break;
|
|
|
+ case GGML_OP_SUB: op_str = "sub"; break;
|
|
|
+ case GGML_OP_MUL: op_str = "mul"; break;
|
|
|
+ case GGML_OP_DIV: op_str = "div"; break;
|
|
|
+ default: GGML_ABORT("fatal error");
|
|
|
+ };
|
|
|
+
|
|
|
+ if (row) {
|
|
|
+ snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
|
|
|
+ } else {
|
|
|
+ snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
|
|
|
+ }
|
|
|
+
|
|
|
+ snprintf(name, 256, "%s", base);
|
|
|
+
|
|
|
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
|
|
|
+ if (res) {
|
|
|
+ // kernel found
|
|
|
+ return res;
|
|
|
+ }
|
|
|
+
|
|
|
+ return ggml_metal_compile_kernel(backend, base, name, nil);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_rms_norm(
|
|
|
+ ggml_backend_t backend, struct ggml_tensor * op,
|
|
|
+ int32_t n_fuse) {
|
|
|
+ struct ggml_backend_metal_context * ctx = backend->context;
|
|
|
+
|
|
|
+ char base[256];
|
|
|
+ char name[256];
|
|
|
+
|
|
|
+ @autoreleasepool {
|
|
|
+ switch (n_fuse) {
|
|
|
+ case 1: snprintf(base, 256, "kernel_rms_norm"); break;
|
|
|
+ case 2: snprintf(base, 256, "kernel_rms_norm_mul"); break;
|
|
|
+ case 3: snprintf(base, 256, "kernel_rms_norm_mul_add"); break;
|
|
|
+ default: GGML_ABORT("fatal error");
|
|
|
+ }
|
|
|
+
|
|
|
+ snprintf(name, 256, "%s", base);
|
|
|
+
|
|
|
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
|
|
|
+ if (res) {
|
|
|
+ // kernel found
|
|
|
+ return res;
|
|
|
+ }
|
|
|
+
|
|
|
+ return ggml_metal_compile_kernel(backend, base, name, nil);
|
|
|
+ }
|
|
|
+
|
|
|
+ GGML_UNUSED(op);
|
|
|
+}
|
|
|
+
|
|
|
static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
|
|
GGML_LOG_INFO("%s: deallocating\n", __func__);
|
|
|
|
|
|
@@ -2359,8 +2376,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
|
|
|
|
|
bool bcast_row = false;
|
|
|
|
|
|
- id<MTLComputePipelineState> pipeline = nil;
|
|
|
-
|
|
|
ggml_metal_kargs_bin args = {
|
|
|
/*.ne00 =*/ ne00,
|
|
|
/*.ne01 =*/ ne01,
|
|
|
@@ -2441,55 +2456,19 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ id<MTLComputePipelineState> pipeline = nil;
|
|
|
+
|
|
|
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
// src1 is a row
|
|
|
GGML_ASSERT(ne11 == 1);
|
|
|
|
|
|
- switch (dst->op) {
|
|
|
- case GGML_OP_ADD:
|
|
|
- {
|
|
|
- switch (n_fuse) {
|
|
|
- case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
|
|
|
- case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
|
|
|
- case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
|
|
|
- case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
|
|
|
- case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
|
|
|
- case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
|
|
|
- case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
|
|
|
- case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
|
|
|
- default: GGML_ABORT("fatal error");
|
|
|
- }
|
|
|
- } break;
|
|
|
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
|
|
|
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
|
|
|
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
|
|
|
- default: GGML_ABORT("fatal error");
|
|
|
- }
|
|
|
+ pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, true);
|
|
|
|
|
|
bcast_row = true;
|
|
|
} else {
|
|
|
- switch (dst->op) {
|
|
|
- case GGML_OP_ADD:
|
|
|
- {
|
|
|
- switch (n_fuse) {
|
|
|
- case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
|
|
|
- case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
|
|
|
- case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
|
|
|
- case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
|
|
|
- case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
|
|
|
- case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
|
|
|
- case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
|
|
|
- case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
|
|
|
- default: GGML_ABORT("fatal error");
|
|
|
- }
|
|
|
- } break;
|
|
|
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
|
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
|
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
|
- default: GGML_ABORT("fatal error");
|
|
|
- }
|
|
|
+ pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, false);
|
|
|
}
|
|
|
|
|
|
if (n_fuse > 1) {
|
|
|
@@ -2650,8 +2629,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
|
|
ggml_metal_encode_concurrency_reset(ctx_enc);
|
|
|
}
|
|
|
|
|
|
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
|
|
-
|
|
|
ggml_metal_kargs_bin args = {
|
|
|
/*.ne00 =*/ ne00,
|
|
|
/*.ne01 =*/ ne01,
|
|
|
@@ -2681,6 +2658,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
|
|
/*.o1 =*/ { offs_src1},
|
|
|
};
|
|
|
|
|
|
+ //const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
|
|
+ const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_bin(backend, GGML_OP_ADD, 1, false);
|
|
|
+
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
|
@@ -4659,14 +4639,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- id<MTLComputePipelineState> pipeline;
|
|
|
-
|
|
|
- switch (n_fuse) {
|
|
|
- case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
|
|
|
- case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
|
|
|
- case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
|
|
|
- default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
|
|
|
- }
|
|
|
+ const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_rms_norm(backend, node, n_fuse);
|
|
|
|
|
|
int nth = 32; // SIMD width
|
|
|
|