فهرست منبع

metal : refactor mat-vec code (#12569)

* metal : refactor mat-vec code

ggml-ci

* metal : rename all_sum -> sum_all

ggml-ci

* metal : fix comments [no ci]

* metal : fix nr constant [no ci]

* metal : mv q6_K support nr0 > 1

ggml-ci

* metal : reduce register pressure

ggml-ci

* metal : fix typo [no ci]

* metal : reduce register pressure

ggml-ci
Georgi Gerganov 9 ماه پیش
والد
کامیت
b3298fa47a
3فایلهای تغییر یافته به همراه384 افزوده شده و 361 حذف شده
  1. 64 0
      ggml/src/ggml-metal/ggml-metal-impl.h
  2. 127 171
      ggml/src/ggml-metal/ggml-metal.m
  3. 193 190
      ggml/src/ggml-metal/ggml-metal.metal

+ 64 - 0
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -1,6 +1,70 @@
 #ifndef GGML_METAL_IMPL
 #ifndef GGML_METAL_IMPL
 #define GGML_METAL_IMPL
 #define GGML_METAL_IMPL
 
 
+// kernel parameters for mat-vec threadgroups
+//
+// N_R0: number of src0 rows to process per simdgroup
+// N_SG: number of simdgroups per threadgroup
+//
+// TODO: for optimal performance, become function of the device and work size
+
+#define N_R0_Q4_0 4
+#define N_SG_Q4_0 2
+
+#define N_R0_Q4_1 4
+#define N_SG_Q4_1 2
+
+#define N_R0_Q5_0 4
+#define N_SG_Q5_0 2
+
+#define N_R0_Q5_1 4
+#define N_SG_Q5_1 2
+
+#define N_R0_Q8_0 4
+#define N_SG_Q8_0 2
+
+#define N_R0_Q2_K 4
+#define N_SG_Q2_K 2
+
+#define N_R0_Q3_K 2
+#define N_SG_Q3_K 2
+
+#define N_R0_Q4_K 4
+#define N_SG_Q4_K 2
+
+#define N_R0_Q5_K 2
+#define N_SG_Q5_K 2
+
+#define N_R0_Q6_K 1
+#define N_SG_Q6_K 2
+
+#define N_R0_IQ1_S 4
+#define N_SG_IQ1_S 2
+
+#define N_R0_IQ1_M 4
+#define N_SG_IQ1_M 2
+
+#define N_R0_IQ2_XXS 4
+#define N_SG_IQ2_XXS 2
+
+#define N_R0_IQ2_XS 4
+#define N_SG_IQ2_XS 2
+
+#define N_R0_IQ2_S 4
+#define N_SG_IQ2_S 2
+
+#define N_R0_IQ3_XXS 4
+#define N_SG_IQ3_XXS 2
+
+#define N_R0_IQ3_S 4
+#define N_SG_IQ3_S 2
+
+#define N_R0_IQ4_NL 2
+#define N_SG_IQ4_NL 2
+
+#define N_R0_IQ4_XS 2
+#define N_SG_IQ4_XS 2
+
 // kernel argument structs
 // kernel argument structs
 //
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage

+ 127 - 171
ggml/src/ggml-metal/ggml-metal.m

@@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node(
                     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                     [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                 } else {
-                    int nth0 = 32;
-                    int nth1 = 1;
-                    int nrows = 1;
-                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                     id<MTLComputePipelineState> pipeline = nil;
                     id<MTLComputePipelineState> pipeline = nil;
 
 
+                    int nsg = 0; // number of simdgroups
+                    int nr0 = 0; // number of src0 rows per simdgroup
+                    int nr1 = 1; // number of src1 rows per threadgroup
+
+                    size_t smem = 0; // shared memory
+
                     // use custom matrix x vector kernel
                     // use custom matrix x vector kernel
                     switch (src0t) {
                     switch (src0t) {
                         case GGML_TYPE_F32:
                         case GGML_TYPE_F32:
                             {
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nsg = 1;
+                                nr0 = 1;
+                                nr1 = 4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
-                                nrows = 4;
                             } break;
                             } break;
                         case GGML_TYPE_F16:
                         case GGML_TYPE_F16:
                             {
                             {
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
                                 if (src1t == GGML_TYPE_F32) {
                                     if (ne11 * ne12 < 4) {
                                     if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
-                                        nrows = ne11;
+                                        nr1 = ne11;
                                     } else {
                                     } else {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
-                                        nrows = 4;
+                                        nr1 = 4;
                                     }
                                     }
                                 } else {
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                                 }
                             } break;
                             } break;
                         case GGML_TYPE_BF16:
                         case GGML_TYPE_BF16:
                             {
                             {
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
                                 if (src1t == GGML_TYPE_F32) {
                                     if (ne11 * ne12 < 4) {
                                     if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
-                                        nrows = ne11;
+                                        nr1 = ne11;
                                     } else {
                                     } else {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
-                                        nrows = 4;
+                                        nr1 = 4;
                                     }
                                     }
                                 } else {
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                                 }
                             } break;
                             } break;
                         case GGML_TYPE_Q4_0:
                         case GGML_TYPE_Q4_0:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_0;
+                                nr0 = N_R0_Q4_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q4_1:
                         case GGML_TYPE_Q4_1:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_1;
+                                nr0 = N_R0_Q4_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q5_0:
                         case GGML_TYPE_Q5_0:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_0;
+                                nr0 = N_R0_Q5_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q5_1:
                         case GGML_TYPE_Q5_1:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_1;
+                                nr0 = N_R0_Q5_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q8_0:
                         case GGML_TYPE_Q8_0:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q8_0;
+                                nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q2_K:
                         case GGML_TYPE_Q2_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q2_K;
+                                nr0 = N_R0_Q2_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q3_K:
                         case GGML_TYPE_Q3_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q3_K;
+                                nr0 = N_R0_Q3_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q4_K:
                         case GGML_TYPE_Q4_K:
                             {
                             {
-                                nth0 = 4; //1;
-                                nth1 = 8; //32;
+                                nsg = N_SG_Q4_K;
+                                nr0 = N_R0_Q4_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q5_K:
                         case GGML_TYPE_Q5_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q5_K;
+                                nr0 = N_R0_Q5_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q6_K:
                         case GGML_TYPE_Q6_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q6_K;
+                                nr0 = N_R0_Q6_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ2_XXS:
                         case GGML_TYPE_IQ2_XXS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XXS;
+                                nr0 = N_R0_IQ2_XXS;
+                                smem = 256*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ2_XS:
                         case GGML_TYPE_IQ2_XS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XS;
+                                nr0 = N_R0_IQ2_XS;
+                                smem = 512*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ3_XXS:
                         case GGML_TYPE_IQ3_XXS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_XXS;
+                                nr0 = N_R0_IQ3_XXS;
+                                smem = 256*4+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ3_S:
                         case GGML_TYPE_IQ3_S:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_S;
+                                nr0 = N_R0_IQ3_S;
+                                smem = 512*4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ2_S:
                         case GGML_TYPE_IQ2_S:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_S;
+                                nr0 = N_R0_IQ2_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ1_S:
                         case GGML_TYPE_IQ1_S:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_S;
+                                nr0 = N_R0_IQ1_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ1_M:
                         case GGML_TYPE_IQ1_M:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_M;
+                                nr0 = N_R0_IQ1_M;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ4_NL:
                         case GGML_TYPE_IQ4_NL:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_NL;
+                                nr0 = N_R0_IQ4_NL;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ4_XS:
                         case GGML_TYPE_IQ4_XS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_XS;
+                                nr0 = N_R0_IQ4_XS;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
                             } break;
                             } break;
                         default:
                         default:
@@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
 
-                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                        src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                        const int mem_size = 32*sizeof(float);
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q3_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    } else {
-                        const int64_t ny = (ne11 + nrows - 1)/nrows;
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (smem > 0) {
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
                     }
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
                 }
             } break;
             } break;
         case GGML_OP_MUL_MAT_ID:
         case GGML_OP_MUL_MAT_ID:
@@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node(
 
 
                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                 } else {
-                    int nth0 = 32;
-                    int nth1 = 1;
-                    int nrows = 1;
-                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                     id<MTLComputePipelineState> pipeline = nil;
                     id<MTLComputePipelineState> pipeline = nil;
 
 
+                    int nsg = 0; // number of simdgroups
+                    int nr0 = 0; // number of src0 rows per simdgroup
+                    int nr1 = 1; // number of src1 rows per threadgroup
+
+                    size_t smem = 0; // shared memory
+
                     // use custom matrix x vector kernel
                     // use custom matrix x vector kernel
                     switch (src0t) {
                     switch (src0t) {
                         case GGML_TYPE_F32:
                         case GGML_TYPE_F32:
                             {
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_F16:
                         case GGML_TYPE_F16:
                             {
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_BF16:
                         case GGML_TYPE_BF16:
                             {
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q4_0:
                         case GGML_TYPE_Q4_0:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_0;
+                                nr0 = N_R0_Q4_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q4_1:
                         case GGML_TYPE_Q4_1:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_1;
+                                nr0 = N_R0_Q4_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q5_0:
                         case GGML_TYPE_Q5_0:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_0;
+                                nr0 = N_R0_Q5_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q5_1:
                         case GGML_TYPE_Q5_1:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_1;
+                                nr0 = N_R0_Q5_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q8_0:
                         case GGML_TYPE_Q8_0:
                             {
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q8_0;
+                                nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q2_K:
                         case GGML_TYPE_Q2_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q2_K;
+                                nr0 = N_R0_Q2_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q3_K:
                         case GGML_TYPE_Q3_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q3_K;
+                                nr0 = N_R0_Q3_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q4_K:
                         case GGML_TYPE_Q4_K:
                             {
                             {
-                                nth0 = 4; //1;
-                                nth1 = 8; //32;
+                                nsg = N_SG_Q4_K;
+                                nr0 = N_R0_Q4_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q5_K:
                         case GGML_TYPE_Q5_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q5_K;
+                                nr0 = N_R0_Q5_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_Q6_K:
                         case GGML_TYPE_Q6_K:
                             {
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q6_K;
+                                nr0 = N_R0_Q6_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ2_XXS:
                         case GGML_TYPE_IQ2_XXS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XXS;
+                                nr0 = N_R0_IQ2_XXS;
+                                smem = 256*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ2_XS:
                         case GGML_TYPE_IQ2_XS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XS;
+                                nr0 = N_R0_IQ2_XS;
+                                smem = 512*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ3_XXS:
                         case GGML_TYPE_IQ3_XXS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_XXS;
+                                nr0 = N_R0_IQ3_XXS;
+                                smem = 256*4+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ3_S:
                         case GGML_TYPE_IQ3_S:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_S;
+                                nr0 = N_R0_IQ3_S;
+                                smem = 512*4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ2_S:
                         case GGML_TYPE_IQ2_S:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_S;
+                                nr0 = N_R0_IQ2_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ1_S:
                         case GGML_TYPE_IQ1_S:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_S;
+                                nr0 = N_R0_IQ1_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ1_M:
                         case GGML_TYPE_IQ1_M:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_M;
+                                nr0 = N_R0_IQ1_M;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ4_NL:
                         case GGML_TYPE_IQ4_NL:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_NL;
+                                nr0 = N_R0_IQ4_NL;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                             } break;
                             } break;
                         case GGML_TYPE_IQ4_XS:
                         case GGML_TYPE_IQ4_XS:
                             {
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_XS;
+                                nr0 = N_R0_IQ4_XS;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
                             } break;
                             } break;
                         default:
                         default:
@@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node(
                     };
                     };
 
 
                     if (ggml_is_quantized(src0t)) {
                     if (ggml_is_quantized(src0t)) {
-                        GGML_ASSERT(ne00 >= nth0*nth1);
+                        GGML_ASSERT(ne00 >= nsg*nr0);
                     }
                     }
 
 
                     ggml_metal_kargs_mul_mv_id args = {
                     ggml_metal_kargs_mul_mv_id args = {
@@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
 
                     const int64_t _ne1 = 1;
                     const int64_t _ne1 = 1;
-                    const int tgz = dst_rows;
+                    const int64_t ne123 = dst_rows;
 
 
-                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                            src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                            src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                        const int mem_size = 32*sizeof(float);
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q3_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    } else {
-                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (smem > 0) {
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
                     }
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
                 }
             } break;
             } break;
         case GGML_OP_GET_ROWS:
         case GGML_OP_GET_ROWS:

تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 193 - 190
ggml/src/ggml-metal/ggml-metal.metal


برخی فایل ها در این مقایسه diff نمایش داده نمی شوند زیرا تعداد فایل ها بسیار زیاد است