1
0
Эх сурвалжийг харах

metal : refactor kernel args into structs (#10238)

* metal : add kernel arg structs (wip)

* metal : fattn args

ggml-ci

* metal : cont + avoid potential int overflow [no ci]

* metal : mul mat struct (wip)

* cont : mul mat vec

* cont : pass by reference

* cont : args is first argument

* cont : use char ptr

* cont : shmem style

* cont : thread counters style

* cont : mul mm id

ggml-ci

* cont : int safety + register optimizations

ggml-ci

* metal : GGML_OP_CONCAT

ggml-ci

* metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV

* metal : GGML_OP_REPEAT

* metal : GGML_OP_CPY

* metal : GGML_OP_RMS_NORM

* metal : GGML_OP_NORM

* metal : add TODOs for rest of ops

* ggml : add ggml-metal-impl.h

ggml-ci
Georgi Gerganov 1 жил өмнө
parent
commit
cf32a9b93a

+ 4 - 1
Makefile

@@ -906,6 +906,7 @@ endif # GGML_METAL
 ifdef GGML_METAL
 ggml/src/ggml-metal/ggml-metal.o: \
 	ggml/src/ggml-metal/ggml-metal.m \
+	ggml/src/ggml-metal/ggml-metal-impl.h \
 	ggml/include/ggml-metal.h \
 	ggml/include/ggml.h
 	$(CC) $(CFLAGS) -c $< -o $@
@@ -913,9 +914,11 @@ ggml/src/ggml-metal/ggml-metal.o: \
 ifdef GGML_METAL_EMBED_LIBRARY
 ggml/src/ggml-metal-embed.o: \
 	ggml/src/ggml-metal/ggml-metal.metal \
+	ggml/src/ggml-metal/ggml-metal-impl.h \
 	ggml/src/ggml-common.h
 	@echo "Embedding Metal library"
-	@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal
+	@sed -e '/__embed_ggml-common.h__/r      ggml/src/ggml-common.h'                -e '/__embed_ggml-common.h__/d'      < ggml/src/ggml-metal/ggml-metal.metal           > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
+	@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
 	$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
 	@echo ".section __DATA, __ggml_metallib"                       >  $(TEMP_ASSEMBLY)/ggml-metal-embed.s
 	@echo ".globl _ggml_metallib_start"                            >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s

+ 11 - 7
ggml/src/ggml-metal/CMakeLists.txt

@@ -25,9 +25,10 @@ if (GGML_METAL_USE_BF16)
     add_compile_definitions(GGML_METAL_USE_BF16)
 endif()
 
-# copy ggml-common.h and ggml-metal.metal to bin directory
-configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
-configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
+# copy metal files to bin directory
+configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h     COPYONLY)
+configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
+configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
 
 if (GGML_METAL_EMBED_LIBRARY)
     enable_language(ASM)
@@ -36,24 +37,27 @@ if (GGML_METAL_EMBED_LIBRARY)
 
     set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
     set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+    set(METALLIB_IMPL   "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
 
     file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
 
     # merge ggml-common.h and ggml-metal.metal into a single file
-    set(METALLIB_EMBED_ASM    "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
-    set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+    set(METALLIB_EMBED_ASM        "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
+    set(METALLIB_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+    set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
 
     add_custom_command(
         OUTPUT ${METALLIB_EMBED_ASM}
         COMMAND echo "Embedding Metal library"
-        COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED}
+        COMMAND sed -e '/__embed_ggml-common.h__/r         ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d'         < ${METALLIB_SOURCE}           > ${METALLIB_SOURCE_EMBED_TMP}
+        COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}'   -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED}
         COMMAND echo ".section __DATA,__ggml_metallib"          >  ${METALLIB_EMBED_ASM}
         COMMAND echo ".globl _ggml_metallib_start"              >> ${METALLIB_EMBED_ASM}
         COMMAND echo "_ggml_metallib_start:"                    >> ${METALLIB_EMBED_ASM}
         COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
         COMMAND echo ".globl _ggml_metallib_end"                >> ${METALLIB_EMBED_ASM}
         COMMAND echo "_ggml_metallib_end:"                      >> ${METALLIB_EMBED_ASM}
-        DEPENDS ggml-metal.metal ../ggml-common.h
+        DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
         COMMENT "Generate assembly for embedded Metal library"
     )
 

+ 249 - 0
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -0,0 +1,249 @@
+#ifndef GGML_METAL_IMPL
+#define GGML_METAL_IMPL
+
+// kernel argument structs
+//
+// - element counters (e.g. ne00) typically use int32_t to reduce register usage
+//   however, be careful from int overflows when using those in the kernel implementation
+//
+// - strides (e.g. nb00) use uint64_t
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  dim;
+} ggml_metal_kargs_concat;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+} ggml_metal_kargs_bin;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_repeat;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_cpy;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  n_past;
+    int32_t  n_dims;
+    int32_t  n_ctx_orig;
+    float    freq_base;
+    float    freq_scale;
+    float    ext_factor;
+    float    attn_factor;
+    float    beta_fast;
+    float    beta_slow;
+} ggml_metal_kargs_rope;
+
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    uint64_t nb_12_1;
+    uint64_t nb_12_2;
+    uint64_t nb_12_3;
+    uint64_t nb31;
+    int32_t  ne1;
+    int32_t  ne2;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint16_t n_head_log2;
+    float    logit_softcap;
+} ggml_metal_kargs_flash_attn_ext;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mv;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+} ggml_metal_kargs_mul_mm_id;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+    uint64_t nb1;
+} ggml_metal_kargs_mul_mv_id;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_norm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_rms_norm;
+
+#endif // GGML_METAL_IMPL

+ 382 - 299
ggml/src/ggml-metal/ggml-metal.m

@@ -2,6 +2,7 @@
 
 #import "ggml-impl.h"
 #import "ggml-backend-impl.h"
+#import "ggml-metal-impl.h"
 
 #import <Foundation/Foundation.h>
 
@@ -1193,35 +1194,39 @@ static void ggml_metal_encode_node(
 
                 const int32_t dim = ((const int32_t *) dst->op_params)[0];
 
+                ggml_metal_kargs_concat args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                    /*.dim  =*/ dim,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                [encoder setBytes:&dim  length:sizeof(dim)  atIndex:27];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 const int nth = MIN(1024, ne0);
 
@@ -1239,8 +1244,6 @@ static void ggml_metal_encode_node(
 
                 bool bcast_row = false;
 
-                int64_t nb = ne00; // used by the "row" kernels
-
                 id<MTLComputePipelineState> pipeline = nil;
 
                 if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
@@ -1249,7 +1252,6 @@ static void ggml_metal_encode_node(
                     // src1 is a row
                     GGML_ASSERT(ne11 == 1);
 
-                    nb = ne00 / 4;
                     switch (dst->op) {
                         case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
                         case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
@@ -1269,36 +1271,39 @@ static void ggml_metal_encode_node(
                     }
                 }
 
+                ggml_metal_kargs_bin args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                    /*.offs =*/ offs,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-                [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 if (bcast_row) {
                     const int64_t n = ggml_nelements(dst)/4;
@@ -1322,25 +1327,29 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 }
 
+                ggml_metal_kargs_repeat args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
@@ -1369,25 +1378,29 @@ static void ggml_metal_encode_node(
 
                     const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
 
+                    ggml_metal_kargs_cpy args = {
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.ne03 =*/ ne03,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.nb03 =*/ nb03,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.ne2  =*/ ne2,
+                        /*.ne3  =*/ ne3,
+                        /*.nb0  =*/ nb0,
+                        /*.nb1  =*/ nb1,
+                        /*.nb2  =*/ nb2,
+                        /*.nb3  =*/ nb3,
+                    };
+
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                    [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
                     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
@@ -1396,35 +1409,39 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
 
+                ggml_metal_kargs_bin args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ pnb1,
+                    /*.nb02 =*/ pnb2,
+                    /*.nb03 =*/ pnb3,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ pnb1,
+                    /*.nb2  =*/ pnb2,
+                    /*.nb3  =*/ pnb3,
+                    /*.offs =*/ offs,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
-                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
-                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
-                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
-                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
-                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
@@ -1465,10 +1482,10 @@ static void ggml_metal_encode_node(
                 memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
-                [encoder setBytes:&min length:sizeof(min) atIndex:2];
-                [encoder setBytes:&max length:sizeof(max) atIndex:3];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&min   length:sizeof(min) atIndex:2];
+                [encoder setBytes:&max   length:sizeof(max) atIndex:3];
 
                 const int64_t n = ggml_nelements(dst);
 
@@ -1640,6 +1657,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -1715,6 +1733,8 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+                // TODO: add ggml_metal_kargs struct
+                // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 if (id_src1) {
@@ -1731,6 +1751,7 @@ static void ggml_metal_encode_node(
                 [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
                 [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
                 [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1747,6 +1768,7 @@ static void ggml_metal_encode_node(
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
                 }
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -1771,6 +1793,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
@@ -1841,6 +1864,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1959,24 +1983,29 @@ static void ggml_metal_encode_node(
                                 default: GGML_ABORT("MUL MAT-MAT not implemented");
                             }
 
+                            ggml_metal_kargs_mul_mm args = {
+                                /*.ne00 =*/ ne00,
+                                /*.ne02 =*/ ne02,
+                                /*.nb01 =*/ nb01,
+                                /*.nb02 =*/ nb02,
+                                /*.nb03 =*/ nb03,
+                                /*.ne12 =*/ ne12,
+                                /*.nb10 =*/ nb10,
+                                /*.nb11 =*/ nb11,
+                                /*.nb12 =*/ nb12,
+                                /*.nb13 =*/ nb13,
+                                /*.ne0  =*/ ne0,
+                                /*.ne1  =*/ ne1,
+                                /*.r2   =*/ r2,
+                                /*.r3   =*/ r3,
+                            };
+
                             [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
-                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
-                            [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11];
-                            [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16];
+                            [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
@@ -2154,28 +2183,32 @@ static void ggml_metal_encode_node(
                                     }
                             };
 
+                            ggml_metal_kargs_mul_mv args = {
+                                /*.ne00 =*/ ne00,
+                                /*.ne01 =*/ ne01,
+                                /*.ne02 =*/ ne02,
+                                /*.nb00 =*/ nb00,
+                                /*.nb01 =*/ nb01,
+                                /*.nb02 =*/ nb02,
+                                /*.nb03 =*/ nb03,
+                                /*.ne10 =*/ ne10,
+                                /*.ne11 =*/ ne11,
+                                /*.ne12 =*/ ne12,
+                                /*.nb10 =*/ nb10,
+                                /*.nb11 =*/ nb11,
+                                /*.nb12 =*/ nb12,
+                                /*.nb13 =*/ nb13,
+                                /*.ne0  =*/ ne0,
+                                /*.ne1  =*/ ne1,
+                                /*.r2   =*/ r2,
+                                /*.r3   =*/ r3,
+                            };
+
                             [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
-                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20];
+                            [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
@@ -2288,27 +2321,30 @@ static void ggml_metal_encode_node(
                         default: GGML_ABORT("MUL_MAT_ID not implemented");
                     }
 
+                    ggml_metal_kargs_mul_mm_id args = {
+                        /*.nei0 =*/ ne20,
+                        /*.nei1 =*/ ne21,
+                        /*.nbi1 =*/ nb21,
+                        /*.ne00 =*/ ne00,
+                        /*.ne02 =*/ ne02,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.ne13 =*/ ne13,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                    };
+
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
-                    [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne21    length:sizeof(ne21) atIndex:5];
-                    [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
-                    [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:7];
-                    [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:8];
-                    [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:9];
-                    [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:10];
-                    [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:11];
-                    [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:12];
-                    [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:13];
-                    [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:14];
-                    [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:15];
-                    [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:16];
-                    [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:17];
-                    [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:18];
-                    [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:19];
+                    [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:4];
 
                     [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
 
@@ -2467,30 +2503,34 @@ static void ggml_metal_encode_node(
                         GGML_ASSERT(ne00 >= nth0*nth1);
                     }
 
+                    ggml_metal_kargs_mul_mv_id args = {
+                        /*.nei0 =*/ ne20,
+                        /*.nei1 =*/ ne21,
+                        /*.nbi1 =*/ nb21,
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.ne10 =*/ ne10,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.ne13 =*/ ne13,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.nb1  =*/ nb1,
+                    };
+
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
-                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
-                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
-                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
-                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
-                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
-                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
-                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
-                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
-                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
-                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
-                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
-                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
-                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
                     const int64_t _ne1 = 1;
                     const int tgz = dst_rows;
@@ -2563,6 +2603,7 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
@@ -2586,20 +2627,28 @@ static void ggml_metal_encode_node(
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
                 int nth = 32; // SIMD width
 
-                while (nth < ne00/4 && nth < 1024) {
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
                     nth *= 2;
                 }
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 const int64_t nrows = ggml_nrows(src0);
@@ -2624,6 +2673,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
                 [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
@@ -2641,22 +2691,35 @@ static void ggml_metal_encode_node(
             } break;
         case GGML_OP_NORM:
             {
+                GGML_ASSERT(ne00 % 4 == 0);
                 GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
-                const int nth = MIN(256, ne00);
-
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
 
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 const int64_t nrows = ggml_nrows(src0);
 
@@ -2706,40 +2769,44 @@ static void ggml_metal_encode_node(
                     };
                 }
 
+                ggml_metal_kargs_rope args = {
+                    /*.ne00        =*/ ne00,
+                    /*.ne01        =*/ ne01,
+                    /*.ne02        =*/ ne02,
+                    /*.ne03        =*/ ne03,
+                    /*.nb00        =*/ nb00,
+                    /*.nb01        =*/ nb01,
+                    /*.nb02        =*/ nb02,
+                    /*.nb03        =*/ nb03,
+                    /*.ne0         =*/ ne0,
+                    /*.ne1         =*/ ne1,
+                    /*.ne2         =*/ ne2,
+                    /*.ne3         =*/ ne3,
+                    /*.nb0         =*/ nb0,
+                    /*.nb1         =*/ nb1,
+                    /*.nb2         =*/ nb2,
+                    /*.nb3         =*/ nb3,
+                    /*.n_past      =*/ n_past,
+                    /*.n_dims      =*/ n_dims,
+                    /*.n_ctx_orig  =*/ n_ctx_orig,
+                    /*.freq_base   =*/ freq_base,
+                    /*.freq_scale  =*/ freq_scale,
+                    /*.ext_factor  =*/ ext_factor,
+                    /*.attn_factor =*/ attn_factor,
+                    /*.beta_fast   =*/ beta_fast,
+                    /*.beta_slow   =*/ beta_slow,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
+                [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
                 if (id_src2 != nil) {
-                    [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
                 } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:2];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
                 }
-                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:3];
-                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:10];
-                [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:11];
-                [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:14];
-                [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:15];
-                [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:18];
-                [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
-                [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
-                [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
-                [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
-                [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
-                [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
-                [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
-                [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
-                [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
-                [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
+                [encoder setBuffer:id_dst  offset:offs_dst      atIndex:4];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
@@ -2796,6 +2863,7 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
@@ -2836,6 +2904,7 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -2870,6 +2939,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -2906,6 +2976,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
                 [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
@@ -2927,6 +2998,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -2965,6 +3037,7 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
                 [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
@@ -2983,6 +3056,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
@@ -3224,37 +3298,41 @@ static void ggml_metal_encode_node(
                     }
                 }
 
+                ggml_metal_kargs_flash_attn_ext args = {
+                    /*.ne01          =*/ ne01,
+                    /*.ne02          =*/ ne02,
+                    /*.ne03          =*/ ne03,
+                    /*.nb01          =*/ nb01,
+                    /*.nb02          =*/ nb02,
+                    /*.nb03          =*/ nb03,
+                    /*.ne11          =*/ ne11,
+                    /*.ne_12_2       =*/ ne12,
+                    /*.ne_12_3       =*/ ne13,
+                    /*.nb_12_1       =*/ nb11,
+                    /*.nb_12_2       =*/ nb12,
+                    /*.nb_12_3       =*/ nb13,
+                    /*.nb31          =*/ nb31,
+                    /*.ne1           =*/ ne1,
+                    /*.ne2           =*/ ne2,
+                    /*.scale         =*/ scale,
+                    /*.max_bias      =*/ max_bias,
+                    /*.m0            =*/ m0,
+                    /*.m1            =*/ m1,
+                    /*.n_head_log2   =*/ n_head_log2,
+                    /*.logit_softcap =*/ logit_softcap,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0           atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1           atIndex:1];
-                [encoder setBuffer:id_src2     offset:offs_src2           atIndex:2];
+                [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
+                [encoder setBuffer:id_src2 offset:offs_src2     atIndex:3];
                 if (id_src3) {
-                    [encoder setBuffer:id_src3     offset:offs_src3           atIndex:3];
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
                 } else {
-                    [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
                 }
-                [encoder setBuffer:id_dst        offset:offs_dst              atIndex:4];
-                [encoder setBytes:&ne01          length:sizeof( int64_t)      atIndex:5];
-                [encoder setBytes:&ne02          length:sizeof( int64_t)      atIndex:6];
-                [encoder setBytes:&ne03          length:sizeof( int64_t)      atIndex:7];
-                [encoder setBytes:&nb01          length:sizeof(uint64_t)      atIndex:8];
-                [encoder setBytes:&nb02          length:sizeof(uint64_t)      atIndex:9];
-                [encoder setBytes:&nb03          length:sizeof(uint64_t)      atIndex:10];
-                [encoder setBytes:&ne11          length:sizeof( int64_t)      atIndex:11];
-                [encoder setBytes:&ne12          length:sizeof( int64_t)      atIndex:12];
-                [encoder setBytes:&ne13          length:sizeof( int64_t)      atIndex:13];
-                [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
-                [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
-                [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
-                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
-                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
-                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
-                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
-                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
-                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
-                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
-                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:5];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -3389,25 +3467,29 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
+                ggml_metal_kargs_cpy args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
@@ -3452,6 +3534,7 @@ static void ggml_metal_encode_node(
                 const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
                 const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 292 - 508
ggml/src/ggml-metal/ggml-metal.metal


Энэ ялгаанд хэт олон файл өөрчлөгдсөн тул зарим файлыг харуулаагүй болно