Explorar el Código

metal : simplify kernel arguments using a struct (#3229) (#12194)

* metal : refactor im2col parameters into a struct

* metal: Change im2col offset types from int32_t to uint64_t to support larger memory offsets

* metal : refactor sum_rows parameters into a struct

* metal : refactor soft_max parameters into a struct

* metal : refactor diag_mask_inf parameters into a struct

* metal : refactor ssm_conv parameters into a struct

* metal : refactor ssm_scan parameters into a struct

* metal : refactor get_rows parameters into a struct

* metal : refactor group_norm parameters into a struct

* metal : refactor conv_transpose_1d parameters into a struct

* metal : refactor upscale parameters into a struct

* metal : refactor pad parameters into a struct

* metal : refactor pad_reflect_1d parameters into a struct

* metal : refactor arange parameters into a struct

* metal : refactor timestep_embedding parameters into a struct

* metal : refactor argsort parameters into a struct

* metal : refactor leaky_relu parameters into a struct

* metal : refactor pool_2d parameters into a struct

* metal : fix trailing whitespace

---------

Co-authored-by: alexju <alexju@tencent.com>
BB-fat hace 10 meses
padre
commit
5e2d57b2b2
Se han modificado 3 ficheros con 655 adiciones y 586 borrados
  1. 235 0
      ggml/src/ggml-metal/ggml-metal-impl.h
  2. 260 206
      ggml/src/ggml-metal/ggml-metal.m
  3. 160 380
      ggml/src/ggml-metal/ggml-metal.metal

+ 235 - 0
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -285,4 +285,239 @@ typedef struct {
     float    eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  n_groups;
+    float    eps;
+} ggml_metal_kargs_group_norm;
+
+typedef struct {
+    int32_t  IC;
+    int32_t  IL;
+    int32_t  K;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+} ggml_metal_kargs_conv_transpose_1d;
+
+typedef struct {
+    uint64_t  ofs0;
+    uint64_t  ofs1;
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  CHW;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  d0;
+    int32_t  d1;
+    int32_t  N;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
+} ggml_metal_kargs_im2col;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    int64_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_sum_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint32_t n_head_log2;
+} ggml_metal_kargs_soft_max;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int      n_past;
+} ggml_metal_kargs_diag_mask_inf;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    int64_t  ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_ssm_conv;
+
+typedef struct {
+    int64_t  d_state;
+    int64_t  d_inner;
+    int64_t  n_seq_tokens;
+    int64_t  n_seqs;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb30;
+    uint64_t nb31;
+    uint64_t nb40;
+    uint64_t nb41;
+    uint64_t nb42;
+    uint64_t nb50;
+    uint64_t nb51;
+    uint64_t nb52;
+} ggml_metal_kargs_ssm_scan;
+
+typedef struct {
+    int64_t  ne00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_get_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    sf0;
+    float    sf1;
+    float    sf2;
+    float    sf3;
+} ggml_metal_kargs_upscale;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_pad;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  p0;
+    int32_t  p1;
+} ggml_metal_kargs_pad_reflect_1d;
+
+typedef struct {
+    uint64_t nb1;
+    int      dim;
+    int      max_period;
+} ggml_metal_kargs_timestep_embedding;
+
+typedef struct {
+    float    slope;
+} ggml_metal_kargs_leaky_relu;
+
+typedef struct {
+    int64_t  ncols;
+    int64_t  ncols_pad;
+} ggml_metal_kargs_argsort;
+
+typedef struct {
+    int64_t  ne0;
+    float    start;
+    float    step;
+} ggml_metal_kargs_arange;
+
+typedef struct {
+    int32_t  k0;
+    int32_t  k1;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int64_t  IH;
+    int64_t  IW;
+    int64_t  OH;
+    int64_t  OW;
+    int64_t  parallel_elements;
+} ggml_metal_kargs_pool_2d;
+
 #endif // GGML_METAL_IMPL

+ 260 - 206
ggml/src/ggml-metal/ggml-metal.m

@@ -1945,34 +1945,38 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+
+                ggml_metal_kargs_sum_rows args = {
+                   /*.ne00 =*/ ne00,
+                   /*.ne01 =*/ ne01,
+                   /*.ne02 =*/ ne02,
+                   /*.ne03 =*/ ne03,
+                   /*.nb00 =*/ nb00,
+                   /*.nb01 =*/ nb01,
+                   /*.nb02 =*/ nb02,
+                   /*.nb03 =*/ nb03,
+                   /*.ne10 =*/ ne10,
+                   /*.ne11 =*/ ne11,
+                   /*.ne12 =*/ ne12,
+                   /*.ne13 =*/ ne13,
+                   /*.nb10 =*/ nb10,
+                   /*.nb11 =*/ nb11,
+                   /*.nb12 =*/ nb12,
+                   /*.nb13 =*/ nb13,
+                   /*.ne0  =*/ ne0,
+                   /*.ne1  =*/ ne1,
+                   /*.ne2  =*/ ne2,
+                   /*.ne3  =*/ ne3,
+                   /*.nb0  =*/ nb0,
+                   /*.nb1  =*/ nb1,
+                   /*.nb2  =*/ nb2,
+                   /*.nb3  =*/ nb3,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -2021,8 +2025,17 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-                // TODO: add ggml_metal_kargs struct
-                // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+                ggml_metal_kargs_soft_max args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.scale =*/ scale,
+                    /*.max_bias =*/ max_bias,
+                    /*.m0 =*/ m0,
+                    /*.m1 =*/ m1,
+                    /*.n_head_log2 =*/ n_head_log2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 if (id_src1) {
@@ -2031,14 +2044,7 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
                 }
                 [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
-                [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
-                [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
-                [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
-                [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
-                [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
-                [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
-                [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
-                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+                [encoder setBytes:&args        length:sizeof(args)        atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@@ -2056,13 +2062,16 @@ static void ggml_metal_encode_node(
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
                 }
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_diag_mask_inf args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.n_past =*/ n_past,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+                [encoder setBytes:&args  length:sizeof(args) atIndex:2];
 
                 if (ne00%8 == 0) {
                     [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -2081,27 +2090,30 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_ssm_conv args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
-                [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
-                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
-                [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
-                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
-                [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -2152,7 +2164,31 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_ssm_scan args = {
+                    /*.d_state =*/ d_state,
+                    /*.d_inner =*/ d_inner,
+                    /*.n_seq_tokens =*/ n_seq_tokens,
+                    /*.n_seqs =*/ n_seqs,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.nb20 =*/ nb20,
+                    /*.nb21 =*/ nb21,
+                    /*.nb22 =*/ nb22,
+                    /*.nb30 =*/ nb30,
+                    /*.nb31 =*/ nb31,
+                    /*.nb40 =*/ nb40,
+                    /*.nb41 =*/ nb41,
+                    /*.nb42 =*/ nb42,
+                    /*.nb50 =*/ nb50,
+                    /*.nb51 =*/ nb51,
+                    /*.nb52 =*/ nb52,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2161,30 +2197,7 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-
-                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
-                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
-                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
-                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
-
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
-                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
-                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
-                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
-                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
-                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:7];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -3041,19 +3054,22 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_get_rows args = {
+                    /*.ne00 =*/ ne00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.ne10 =*/ ne10,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
-                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+                [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
             } break;
@@ -3110,18 +3126,21 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_group_norm args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.n_groups =*/ n_groups,
+                    /*.eps =*/ eps,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
                 [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
-                [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                [encoder setBytes:&args     length:sizeof(args)     atIndex:2];
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -3279,8 +3298,8 @@ static void ggml_metal_encode_node(
 
                 const int32_t CHW = IC * KH * KW;
 
-                const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
-                const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+                const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
 
@@ -3302,27 +3321,30 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_im2col args = {
+                    /*.ofs0 =*/ ofs0,
+                    /*.ofs1 =*/ ofs1,
+                    /*.IW   =*/ IW,
+                    /*.IH   =*/ IH,
+                    /*.CHW  =*/ CHW,
+                    /*.s0   =*/ s0,
+                    /*.s1   =*/ s1,
+                    /*.p0   =*/ p0,
+                    /*.p1   =*/ p1,
+                    /*.d0   =*/ d0,
+                    /*.d1   =*/ d1,
+                    /*.N    =*/ N,
+                    /*.KH   =*/ KH,
+                    /*.KW   =*/ KW,
+                    /*.KHW  =*/ KH * KW,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&ofs0    length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&ofs1    length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&IW      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&IH      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&CHW     length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:8];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:9];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:10];
-                [encoder setBytes:&d0      length:sizeof(int32_t) atIndex:11];
-                [encoder setBytes:&d1      length:sizeof(int32_t) atIndex:12];
+                [encoder setBytes:&args length:sizeof(args)       atIndex:2];
 
                 if (is_gt_mttpt) {
-                    [encoder setBytes:&N   length:sizeof(int32_t) atIndex:13];
-                    [encoder setBytes:&KH  length:sizeof(int32_t) atIndex:14];
-                    [encoder setBytes:&KW  length:sizeof(int32_t) atIndex:15];
-
                     const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
 
                     const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
@@ -3362,16 +3384,20 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
+                ggml_metal_kargs_conv_transpose_1d args = {
+                    /*.IC =*/ IC,
+                    /*.IL =*/ IL,
+                    /*.K  =*/ K,
+                    /*.s0 =*/ s0,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0         atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1         atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst          atIndex:2];
-                [encoder setBytes:&IC      length:sizeof( int32_t)  atIndex:3];
-                [encoder setBytes:&IL      length:sizeof( int32_t)  atIndex:4];
-                [encoder setBytes:&K       length:sizeof( int32_t)  atIndex:5];
-                [encoder setBytes:&s0      length:sizeof( int32_t)  atIndex:6];
-                [encoder setBytes:&nb0     length:sizeof(uint64_t)  atIndex:7];
-                [encoder setBytes:&nb1     length:sizeof(uint64_t)  atIndex:8];
+                [encoder setBytes:&args    length:sizeof(args)       atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -3386,30 +3412,33 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_upscale args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3,
+                    /*.sf0 =*/ sf0,
+                    /*.sf1 =*/ sf1,
+                    /*.sf2 =*/ sf2,
+                    /*.sf3 =*/ sf3
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-                [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
-                [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
-                [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
-                [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
@@ -3421,26 +3450,29 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_pad args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3455,24 +3487,31 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
 
+                ggml_metal_kargs_pad_reflect_1d args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3,
+                    /*.p0 =*/ p0,
+                    /*.p1 =*/ p1
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];
-                [encoder setBytes:&p0   length:sizeof(p0)   atIndex:15];
-                [encoder setBytes:&p1   length:sizeof(p1)   atIndex:16];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3490,12 +3529,15 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_arange args = {
+                    /*.ne0 =*/ ne0,
+                    /*.start =*/ start,
+                    /*.step =*/ step
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
-                [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
-                [encoder setBytes:&start length:sizeof(start) atIndex:2];
-                [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:0];
+                [encoder setBytes:&args length:sizeof(args) atIndex:1];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3512,13 +3554,16 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_timestep_embedding args = {
+                    /*.nb1 =*/ nb1,
+                    /*.dim =*/ dim,
+                    /*.max_period =*/ max_period
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
-                [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
-                [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, half);
 
@@ -3551,12 +3596,15 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_argsort args = {
+                    /*.ncols =*/ ne00,
+                    /*.ncols_pad =*/ ne00_padded
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
@@ -3570,11 +3618,14 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_leaky_relu args = {
+                    /*.slope =*/ slope
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
-                [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args)   atIndex:2];
 
                 const int64_t n = ggml_nelements(dst);
 
@@ -4150,21 +4201,24 @@ static void ggml_metal_encode_node(
                 const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
                 const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_pool_2d args_pool_2d = {
+                    /* .k0 = */ k0,
+                    /* .k1 = */ k1,
+                    /* .s0 = */ s0,
+                    /* .s1 = */ s1,
+                    /* .p0 = */ p0,
+                    /* .p1 = */ p1,
+                    /* .IH = */ IH,
+                    /* .IW = */ IW,
+                    /* .OH = */ OH,
+                    /* .OW = */ OW,
+                    /* .parallel_elements = */ parallel_elements
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&k0      length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&k1      length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&IH      length:sizeof(int64_t) atIndex:8];
-                [encoder setBytes:&IW      length:sizeof(int64_t) atIndex:9];
-                [encoder setBytes:&OH      length:sizeof(int64_t) atIndex:10];
-                [encoder setBytes:&OW      length:sizeof(int64_t) atIndex:11];
-                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
             } break;

La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 160 - 380
ggml/src/ggml-metal/ggml-metal.metal


Algunos archivos no se mostraron porque demasiados archivos cambiaron en este cambio