
ggml : group all experts in a single ggml_mul_mat_id (#6505)

* ggml : group all experts in a single ggml_mul_mat_id
cuda : improve mmid row copy

* cuda : fix bin bcast with non-cont src0

* test-backend-ops : only run all mul mat tests for base types

* llama : disable moe offloading with SYCL

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
slaren, 1 year ago
commit 0d56246f4b
12 changed files with 933 additions and 763 deletions
  1. examples/imatrix/imatrix.cpp (+36 -21)
  2. ggml-cuda.cu (+134 -45)
  3. ggml-cuda/binbcast.cu (+68 -24)
  4. ggml-cuda/convert.cu (+2 -0)
  5. ggml-metal.m (+61 -68)
  6. ggml-metal.metal (+362 -420)
  7. ggml-sycl.cpp (+1 -1)
  8. ggml.c (+62 -61)
  9. ggml.h (+2 -4)
  10. llama.cpp (+138 -85)
  11. scripts/compare-commits.sh (+1 -13)
  12. tests/test-backend-ops.cpp (+66 -21)
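
For orientation before the per-file diffs: the commit replaces the old per-expert form of ggml_mul_mat_id (which took an expert index and had to be called once per used expert) with a single call that covers every expert selected in ids. A minimal sketch of the call-site change, with hypothetical tensor names (ffn_up_exps, cur) standing in for whatever the caller uses:

    // before: one matrix multiplication per expert slot, selected by `id`
    //   cur_expert = ggml_mul_mat_id(ctx, ffn_up_exps, ids, id, cur);     // old signature
    // after: one matrix multiplication covering all experts used by the batch
    //   ffn_up_exps : [n_embd, n_ff, n_expert]
    //   cur         : [n_embd, 1, n_tokens]  (broadcast against n_expert_used)
    //   ids         : [n_expert_used, n_tokens], I32
    struct ggml_tensor * up = ggml_mul_mat_id(ctx, ffn_up_exps, cur, ids); // new signature
    //   up          : [n_ff, n_expert_used, n_tokens]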

+ 36 - 21
examples/imatrix/imatrix.cpp

@@ -44,7 +44,7 @@ private:
     std::mutex                             m_mutex;
     int                                    m_last_call = 0;
     std::vector<float>                     m_src1_data;
-    std::vector<int>                       m_ids; // the expert ids from ggml_mul_mat_id
+    std::vector<char>                      m_ids; // the expert ids from ggml_mul_mat_id
                                                   //
     void save_imatrix(const char * file_name) const;
     void keep_imatrix(int ncall) const;
@@ -81,6 +81,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     if (ask) {
         if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications
         if (t->op != GGML_OP_MUL_MAT) return false;
+        // why are small batches ignored (<16 tokens)?
         if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
         if (!(wname.substr(0, 4) == "blk." || (m_params.collect_output_weight && wname == "output.weight"))) return false;
         return true;
@@ -101,14 +102,19 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     // this has been adapted to the new format of storing merged experts in a single 3d tensor
     // ref: https://github.com/ggerganov/llama.cpp/pull/6387
     if (t->op == GGML_OP_MUL_MAT_ID) {
-        const int idx  = ((int32_t *) t->op_params)[0];
+        //   ids  -> [n_experts_used, n_tokens]
+        //   src1 -> [cols, n_expert_used, n_tokens]
         const ggml_tensor * ids = t->src[2];
         const int n_as = src0->ne[2];
+        const int n_ids = ids->ne[0];

         // the top-k selected expert ids are stored in the ids tensor
         // for simplicity, always copy ids to host, because it is small
-        GGML_ASSERT(ids->ne[1] == src1->ne[1]);
-        m_ids.resize(ggml_nbytes(ids)/sizeof(int));
+        // take into account that ids is not contiguous!
+
+        GGML_ASSERT(ids->ne[1] == src1->ne[2]);
+
+        m_ids.resize(ggml_nbytes(ids));
         ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));

         auto & e = m_stats[wname];
@@ -118,26 +124,35 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         //       using the following line, we can correct for that if needed by replacing the line above with:
         //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;

+        if (e.values.empty()) {
+            e.values.resize(src1->ne[0]*n_as, 0);
+        }
+        else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
+            fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
+            exit(1); //GGML_ASSERT(false);
+        }
+        if (m_params.verbosity > 1) {
+            printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
+        }
         // loop over all possible experts, regardless if they are used or not in the batch
         for (int ex = 0; ex < n_as; ++ex) {
             size_t e_start = ex*src1->ne[0];
-            if (e.values.empty()) {
-                e.values.resize(src1->ne[0]*n_as, 0);
-            }
-            else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
-                fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
-                exit(1); //GGML_ASSERT(false);
-            }
-            if (m_params.verbosity > 1) {
-                printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
-            }
-            for (int row = 0; row < (int)src1->ne[1]; ++row) {
-                const int excur = m_ids[row*n_as + idx];
-                GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
-                if (excur != ex) continue;
-                const float * x = data + row * src1->ne[0];
-                for (int j = 0; j < (int)src1->ne[0]; ++j) {
-                    e.values[e_start + j] += x[j]*x[j];
+
+            for (int idx = 0; idx < n_ids; ++idx) {
+                for (int row = 0; row < (int)src1->ne[2]; ++row) {
+                    const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);
+
+                    GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
+
+                    if (excur != ex) continue;
+
+                    const int64_t i11 = idx % src1->ne[1];
+                    const int64_t i12 = row;
+                    const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
+
+                    for (int j = 0; j < (int)src1->ne[0]; ++j) {
+                        e.values[e_start + j] += x[j]*x[j];
+                    }
                 }
             }
             if (e.ncall > m_last_call) {
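
A side note on the ids handling above: because ids may be a non-contiguous view, the collector now copies its raw bytes into a char buffer and indexes it with the tensor's byte strides instead of treating it as a flat int array. A minimal host-side sketch of that addressing (same names as the loop above, shown only for illustration):

    // ids has shape [n_expert_used, n_tokens]; nb[0] and nb[1] are byte strides,
    // so element (idx, row) lives at byte offset row*nb[1] + idx*nb[0]
    const int32_t expert = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);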

+ 134 - 45
ggml-cuda.cu

@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
 
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }

+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+                                                 const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+                                                 int64_t ne11, int64_t ne10,
+                                                 size_t nb11, size_t nb12) {
+    int32_t iid1 = blockIdx.x;
+    int32_t id = blockIdx.y;
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+                                                  const mmid_row_mapping * __restrict__ row_mapping,
+                                                  int64_t ne0,
+                                                  size_t nb1, size_t nb2) {
+    int32_t i = blockIdx.x;
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids  = dst->src[2];

+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

     cudaStream_t stream = ctx.stream();

-    const size_t nb11 = src1->nb[1];
-    const size_t nb1  =  dst->nb[1];
-
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];

     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
@@ -1982,7 +2035,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    ggml_tensor dst_row  = *dst;

     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
@@ -1990,19 +2043,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
+    src0_row.nb[3] = nb02;

-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;

-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;

-            src0_row.data = src0_original + row_id*src0->nb[2];
-            src1_row.data = src1_original + i01*src1->nb[1];
-            dst_row.data  =  dst_original + i01*dst->nb[1];
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data  =  dst_contiguous.get();

-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

-                if (row_id_i != row_id) {
-                    continue;
-                }
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+                    if (row_id_i != i02) {
+                        continue;
+                    }
+
+                    num_src1_rows++;
+                }
             }

             if (num_src1_rows == 0) {
                 continue;
             }

-            src0_row.data = src0_original + row_id*src0->nb[2];
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));

-            src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
+            {
+                dim3 block_dims(std::min((unsigned int)ne10, 768u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ne11, ne10,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            src0_row.data = src0_original + i02*nb02;

+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;

+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;

             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+            {
+                dim3 block_dims(std::min((unsigned int)ne0, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        ne0,
+                        nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
         }
     }
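
The batched branch above works per expert: a kernel packs every src1 row that selects the current expert into a contiguous buffer and records, through an atomic counter, where each packed row came from; after one ggml_cuda_mul_mat over the packed rows, a second kernel scatters the results back using that mapping. A host-side illustration of the mapping logic on plain arrays (hypothetical helper, not part of the patch):

    #include <cstdint>
    #include <vector>

    struct row_mapping { int32_t i1; int32_t i2; }; // (expert slot, token) of a packed row

    // ids is row-major [n_tokens][n_ids]; returns, in packing order, where each
    // row gathered for `expert` has to be written back in the output
    static std::vector<row_mapping> rows_for_expert(const int32_t * ids, int n_ids, int n_tokens, int expert) {
        std::vector<row_mapping> mapping;
        for (int t = 0; t < n_tokens; ++t) {
            for (int s = 0; s < n_ids; ++s) {
                if (ids[t*n_ids + s] == expert) {
                    mapping.push_back({s, t});
                }
            }
        }
        return mapping;
    }

On the device the packing order is whatever the atomic increments produce, which is why the mapping is stored explicitly instead of being recomputed when copying the results back.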
@@ -2487,7 +2575,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;

-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);

     GGML_UNUSED(backend);
 }

+ 68 - 24
ggml-cuda/binbcast.cu

@@ -22,6 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         int ne0, int ne1, int ne2, int ne3,
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1,  int s2,  int s3,
         /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {
         /*int s10,*/ int s11, int s12, int s13) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
@@ -36,9 +37,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     const int i12 = i2 % ne12;
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
     const int i13 = i3 % ne13;
 
 
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
 
     const src0_t * src0_row = src0 + i_src0;
     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;
     const src1_t * src1_row = src1 + i_src1;
@@ -55,6 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
         int ne0, int ne1, int ne2, int ne3,
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1,  int s2,  int s3,
         /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {
         /*int s10,*/ int s11, int s12, int s13) {
 
 
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -72,9 +74,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     const int i12 = i2 % ne12;
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
     const int i13 = i3 % ne13;
 
 
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
 
     const src0_t * src0_row = src0 + i_src0;
     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;
     const src1_t * src1_row = src1 + i_src1;
@@ -101,10 +103,14 @@ struct bin_bcast_cuda {
         int nr[4] = { nr0, nr1, nr2, nr3 };
         int nr[4] = { nr0, nr1, nr2, nr3 };
 
 
         // collapse dimensions until first broadcast dimension
         // collapse dimensions until first broadcast dimension
-        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne[] = {ne0, ne1, ne2, ne3};
+        int64_t cne0[] = {ne00, ne01, ne02, ne03};
         int64_t cne1[] = {ne10, ne11, ne12, ne13};
         int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+
+        size_t cnb[] = {nb0, nb1, nb2, nb3};
+        size_t cnb0[] = {nb00, nb01, nb02, nb03};
         size_t cnb1[] = {nb10, nb11, nb12, nb13};
         size_t cnb1[] = {nb10, nb11, nb12, nb13};
+
         auto collapse = [](int64_t cne[]) {
         auto collapse = [](int64_t cne[]) {
             cne[0] *= cne[1];
             cne[0] *= cne[1];
             cne[1] = cne[2];
             cne[1] = cne[2];
@@ -118,32 +124,47 @@ struct bin_bcast_cuda {
             cnb[3] *= cne[3];
             cnb[3] *= cne[3];
         };
         };
 
 
-        for (int i = 0; i < 4; i++) {
-            if (nr[i] != 1) {
-                break;
-            }
-            if (i > 0) {
-                collapse_nb(cnb0, cne0);
-                collapse_nb(cnb1, cne1);
-                collapse(cne0);
-                collapse(cne1);
+        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+            for (int i = 0; i < 4; i++) {
+                if (nr[i] != 1) {
+                    break;
+                }
+                if (i > 0) {
+                    collapse_nb(cnb, cne);
+                    collapse_nb(cnb0, cne0);
+                    collapse_nb(cnb1, cne1);
+                    collapse(cne);
+                    collapse(cne0);
+                    collapse(cne1);
+                }
             }
             }
         }
         }
+
         {
         {
-            int64_t ne0 = cne0[0];
-            int64_t ne1 = cne0[1];
-            int64_t ne2 = cne0[2];
-            int64_t ne3 = cne0[3];
+            int64_t ne0 = cne[0];
+            int64_t ne1 = cne[1];
+            int64_t ne2 = cne[2];
+            int64_t ne3 = cne[3];
+
+            //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+            //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+            //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+            //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
 
 
             int64_t ne10 = cne1[0];
             int64_t ne10 = cne1[0];
             int64_t ne11 = cne1[1];
             int64_t ne11 = cne1[1];
             int64_t ne12 = cne1[2];
             int64_t ne12 = cne1[2];
             int64_t ne13 = cne1[3];
             int64_t ne13 = cne1[3];
 
 
-            size_t nb0 = cnb0[0];
-            size_t nb1 = cnb0[1];
-            size_t nb2 = cnb0[2];
-            size_t nb3 = cnb0[3];
+            size_t nb0 = cnb[0];
+            size_t nb1 = cnb[1];
+            size_t nb2 = cnb[2];
+            size_t nb3 = cnb[3];
+
+            size_t nb00 = cnb0[0];
+            size_t nb01 = cnb0[1];
+            size_t nb02 = cnb0[2];
+            size_t nb03 = cnb0[3];
 
 
             size_t nb10 = cnb1[0];
             size_t nb10 = cnb1[0];
             size_t nb11 = cnb1[1];
             size_t nb11 = cnb1[1];
@@ -160,7 +181,28 @@ struct bin_bcast_cuda {
             size_t s12 = nb12 / sizeof(src1_t);
             size_t s12 = nb12 / sizeof(src1_t);
             size_t s13 = nb13 / sizeof(src1_t);
             size_t s13 = nb13 / sizeof(src1_t);
 
 
+            size_t s00 = nb00 / sizeof(src0_t);
+            size_t s01 = nb01 / sizeof(src0_t);
+            size_t s02 = nb02 / sizeof(src0_t);
+            size_t s03 = nb03 / sizeof(src0_t);
+
+            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
             GGML_ASSERT(s0 == 1);
             GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s00 == 1);
             GGML_ASSERT(s10 == 1);
             GGML_ASSERT(s10 == 1);
 
 
             const int block_size = 128;
             const int block_size = 128;
@@ -179,13 +221,14 @@ struct bin_bcast_cuda {
             );
             );
 
 
             if (block_nums.z > 65535) {
             if (block_nums.z > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
                 k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
                 k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
                     src0_dd, src1_dd, dst_dd,
                     src0_dd, src1_dd, dst_dd,
                     ne0, ne1, ne2, ne3,
                     ne0, ne1, ne2, ne3,
                     ne10, ne11, ne12, ne13,
                     ne10, ne11, ne12, ne13,
                     /* s0, */ s1, s2, s3,
                     /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
                     /* s10, */ s11, s12, s13);
                     /* s10, */ s11, s12, s13);
             } else {
             } else {
                 k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
                 k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
@@ -193,6 +236,7 @@ struct bin_bcast_cuda {
                     ne0, ne1, ne2, ne3,
                     ne0, ne1, ne2, ne3,
                     ne10, ne11, ne12, ne13,
                     ne10, ne11, ne12, ne13,
                     /* s0, */ s1, s2, s3,
                     /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
                     /* s10, */ s11, s12, s13);
                     /* s10, */ s11, s12, s13);
             }
             }
         }
         }

+ 2 - 0
ggml-cuda/convert.cu

@@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
         vals[ix] = x0[ix];
     }

+    __syncthreads();
+
 #pragma unroll
     for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
         if (need_check && i0 + iy + 2*threadIdx.x >= k) {
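
For context on the added barrier: dequantize_block_q8_0_f16 stages values in shared memory, and threads later read entries written by other threads of the block, so the store phase and the read phase must be separated by __syncthreads(). A generic sketch of the pattern (not the actual kernel):

    __global__ void staged_copy(const float * src, float * dst, int n) {
        __shared__ float vals[256];               // assumes blockDim.x <= 256
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            vals[threadIdx.x] = src[i];
        }
        __syncthreads();                          // all writes visible before any cross-thread read
        const int j = (threadIdx.x + 1) % blockDim.x; // read a neighbour's element
        if (i < n && blockIdx.x*blockDim.x + j < n) {
            dst[i] = vals[j];
        }
    }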

+ 61 - 68
ggml-metal.m

@@ -1732,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                     } break;
                 case GGML_OP_MUL_MAT_ID:
                 case GGML_OP_MUL_MAT_ID:
                     {
                     {
-                        //GGML_ASSERT(ne00 == ne10);
-                        //GGML_ASSERT(ne03 == ne13);
                         const int n_as = src0->ne[2];
                         const int n_as = src0->ne[2];
 
 
-                        // max size of the src1ids array in the kernel shared buffer
-                        GGML_ASSERT(ne11 <= 4096);
-
                         // src2 = ids
                         // src2 = ids
-                        const int64_t  ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+                        const int64_t  ne20 = src2->ne[0];
                         const int64_t  ne21 = src2->ne[1];
                         const int64_t  ne21 = src2->ne[1];
                         const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
                         const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
                         const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
                         const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1761,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(
 
 
                         // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                         // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                         // to the matrix-vector kernel
                         // to the matrix-vector kernel
-                        int ne11_mm_min = n_as;
-
-                        const int idx = ((int32_t *) dst->op_params)[0];
+                        // ne20 = n_used_experts
+                        // ne21 = n_rows
+                        const int dst_rows = ne20*ne21;
+                        const int dst_rows_min = n_as;
 
 
-                        // batch size
-                        GGML_ASSERT(ne21 == ne11); // ?
-                        GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
-                        const uint r2 = 1;
-                        const uint r3 = 1;
+                        // max size of the rowids array in the kernel shared buffer
+                        GGML_ASSERT(dst_rows <= 2048);
 
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1779,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         // !!!
                         // !!!
                         if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                         if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                             ne00 % 32 == 0 && ne00 >= 64 &&
                             ne00 % 32 == 0 && ne00 >= 64 &&
-                            ne11 > ne11_mm_min) {
+                            dst_rows > dst_rows_min) {
 
 
                             // some Metal matrix data types require aligned pointers
                             // some Metal matrix data types require aligned pointers
                             // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
                             // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1821,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                             [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
                             [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
                             [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
                             [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
-                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:4];
-                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:5];
-                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:6];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:9];
-                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:10];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:13];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:14];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:15];
-                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:16];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:17];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:18];
-                            [encoder setBytes:&idx     length:sizeof(idx)  atIndex:19];
-
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
-
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
+                            [encoder setBytes:&ne21    length:sizeof(ne21) atIndex:5];
+                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
+                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:7];
+                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:8];
+                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:9];
+                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:10];
+                            [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:13];
+                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:14];
+                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:15];
+                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:16];
+                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:17];
+                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:18];
+                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:19];
+
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
                         } else {
                             int nth0 = 32;
                             int nth0 = 32;
                             int nth1 = 1;
                             int nth1 = 1;
@@ -1993,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
                                 GGML_ASSERT(ne00 >= nth0*nth1);
                                 GGML_ASSERT(ne00 >= nth0*nth1);
                             }
                             }
 
 
-                            const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
                             [encoder setComputePipelineState:pipeline];
                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                             [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
                             [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
-                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:20];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:21];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:22];
-                            [encoder setBytes:&idx  length:sizeof(idx)  atIndex:23];
+                            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                            [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
+
+                            const int64_t _ne1 = 1;
+                            const int tgz = dst_rows;
 
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
                                 src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                                 src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
                                 const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                                 const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             }
                             else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
                             else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
                                 const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                                 const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             }
                             else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                             else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                                 const int mem_size = 32*sizeof(float);
                                 const int mem_size = 32*sizeof(float);
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             }
                             else if (src0t == GGML_TYPE_Q4_K) {
                             else if (src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             }
                             else if (src0t == GGML_TYPE_Q3_K) {
                             else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
 #ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
 #else
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
 #endif
                             }
                             }
                             else if (src0t == GGML_TYPE_Q5_K) {
                             else if (src0t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             }
                             else if (src0t == GGML_TYPE_Q6_K) {
                             else if (src0t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
                             } else {
-                                const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             }
                         }
                         }
                     } break;
                     } break;

+ 362 - 420
ggml-metal.metal

File diff suppressed because it is too large


+ 1 - 1
ggml-sycl.cpp

@@ -17752,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
 
 
 GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
     GGML_UNUSED(backend);
 }
 
 

+ 62 - 61
ggml.c

@@ -4578,21 +4578,32 @@ void ggml_mul_mat_set_prec(
 
 
 // ggml_mul_mat_id

-// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
-//       this will allow computing all the used experts in a single matrix multiplication
+/*
+    c = ggml_mul_mat_id(ctx, as, b, ids);
+
+    as  -> [cols, rows, n_expert]
+    ids -> [n_experts_used, n_tokens] (i32)
+    b   -> [cols, n_expert_used, n_tokens]
+    c   -> [cols, n_expert_used, n_tokens]
+
+    in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+
+    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
+*/
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
         struct ggml_tensor  * as,
-        struct ggml_tensor  * ids,
-        int                   id,
-        struct ggml_tensor  * b) {
-
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * ids) {
+    GGML_ASSERT(!ggml_is_transposed(as));
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
+    GGML_ASSERT(b->ne[3] == 1); // b is 3d
     GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
-    GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
-    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
-    GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+    GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
     GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
+    GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast

     bool is_node = false;

@@ -4600,11 +4611,9 @@ struct ggml_tensor * ggml_mul_mat_id(
         is_node = true;
     }

-    const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    ggml_set_op_params_i32(result, 0, id);
-
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = as;
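
A worked example of the shapes documented above (hypothetical sizes, only to make the broadcast rule concrete):

    // n_expert = 8, n_expert_used = 2, n_tokens = 5, n_embd = 4096, n_ff = 14336
    // as  : [4096, 14336, 8]   one 4096 x 14336 matrix per expert
    // b   : [4096, 1, 5]       ne[1] = 1 broadcasts against the 2 experts used per token
    // ids : [2, 5]  (I32)      ids[e, t] picks the expert matrix for slot e of token t
    struct ggml_tensor * c = ggml_mul_mat_id(ctx, as, b, ids);
    // c   : [14336, 2, 5]      one output row block per (expert slot, token)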
@@ -11009,11 +11018,6 @@ static void ggml_compute_forward_mul_mat_id(
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
     // we don't support permuted src0 or src1
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@@ -11024,22 +11028,21 @@ static void ggml_compute_forward_mul_mat_id(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
     GGML_ASSERT(nb2 <= nb3);
 
 
-    // broadcast is not supported with mmid
-    assert(ne12 == 1);
-    assert(ne13 == 1);
-
     // row groups
     // row groups
-    const int id   = ggml_get_op_params_i32(dst, 0);
-    const int n_as = src0->ne[2];
+    const int n_ids = ids->ne[0]; // n_expert_used
+    const int n_as  = ne02;       // n_expert
 
 
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
             (char *) params->wdata :
             (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
             (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
 
 
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    int64_t * matrix_rows       = matrix_row_counts + n_as;     // [n_as][ne11]
+    struct mmid_row_mapping {
+        int32_t i1;
+        int32_t i2;
+    };
 
 
-    #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
 
 
    if (params->type == GGML_TASK_TYPE_INIT) {
    if (params->type == GGML_TASK_TYPE_INIT) {
         if (ith != 0) {
         if (ith != 0) {
@@ -11065,13 +11068,18 @@ static void ggml_compute_forward_mul_mat_id(
         // initialize matrix_row_counts
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
         // group rows by src0 matrix
         // group rows by src0 matrix
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+            for (int id = 0; id < n_ids; ++id) {
+                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                assert(i02 >= 0 && i02 < n_as);
 
 
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
-            MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
-            matrix_row_counts[row_id] += 1;
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+                matrix_row_counts[i02] += 1;
+            }
         }
         }
 
 
         return;
         return;
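
The INIT pass above buckets every (slot, token) pair by the expert it routes to, so the compute pass can walk one expert's rows contiguously. A minimal standalone sketch of the same grouping, with hypothetical ids (not from the patch):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct row_mapping { int32_t i1; int32_t i2; }; // i1 = slot in the ids row, i2 = token index

    int main() {
        const int n_as = 4; // hypothetical number of experts
        // ids[token][slot] = selected expert (2 experts used per token here)
        const std::vector<std::vector<int32_t>> ids = {{0, 2}, {2, 3}, {0, 3}};

        std::vector<std::vector<row_mapping>> rows(n_as); // one bucket per expert
        for (int32_t iid1 = 0; iid1 < (int32_t) ids.size(); ++iid1) {
            for (int32_t id = 0; id < (int32_t) ids[iid1].size(); ++id) {
                const int32_t i02 = ids[iid1][id];
                rows[i02].push_back({id, iid1});
            }
        }

        for (int e = 0; e < n_as; ++e) {
            printf("expert %d: %zu rows\n", e, rows[e].size());
        }
        return 0;
    }
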
@@ -11089,15 +11097,13 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
             continue;
         }
         }
 
 
-        size_t src0_offset = cur_a*src0->nb[2];
+        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
 
 
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
 
-        const int64_t nr0 = ne01;           // src0 rows
-        const int64_t nr1 = cne1*ne12*ne13; // src1 rows
-
-        //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
+        const int64_t nr0 = ne01; // src0 rows
+        const int64_t nr1 = cne1; // src1 rows
 
 
         // distribute the thread work across the inner or outer loop based on which one is larger
         // distribute the thread work across the inner or outer loop based on which one is larger
 
 
@@ -11116,13 +11122,11 @@ static void ggml_compute_forward_mul_mat_id(
         const int64_t ir110 = dr1*ith1;
         const int64_t ir110 = dr1*ith1;
         const int64_t ir111 = MIN(ir110 + dr1, nr1);
         const int64_t ir111 = MIN(ir110 + dr1, nr1);
 
 
-        //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
-
         // threads with no work simply yield (not sure if it helps)
         // threads with no work simply yield (not sure if it helps)
-        if (ir010 >= ir011 || ir110 >= ir111) {
-            sched_yield();
-            continue;
-        }
+        //if (ir010 >= ir011 || ir110 >= ir111) {
+        //    sched_yield();
+        //    continue;
+        //}
 
 
         // block-tiling attempt
         // block-tiling attempt
         const int64_t blck_0 = 16;
         const int64_t blck_0 = 16;
@@ -11134,20 +11138,16 @@ static void ggml_compute_forward_mul_mat_id(
         for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
             for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
                 for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
                 for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t  i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
-                    const int64_t  i12 = (ir1 - i13*ne12*cne1)/cne1;
-                    const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
-                    const int64_t  i11 = MMID_MATRIX_ROW(cur_a, _i11);
+                    const int64_t _i12 = ir1; // logical row index for this expert
 
 
-                    // broadcast src0 into src1
-                    //const int64_t i03 = i13/r3;
-                    //const int64_t i02 = i12/r2;
+                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                    const int id       = row_mapping.i1; // selected expert index
 
 
-                    const int64_t i1 = i11;
-                    const int64_t i2 = i12;
-                    const int64_t i3 = i13;
+                    const int64_t  i11 = id % ne11;
+                    const int64_t  i12 = row_mapping.i2; // row index in src1
 
 
-                    const char * src0_row = (const char *) src0->data + src0_offset;
+                    const int64_t  i1 = id;  // selected expert index
+                    const int64_t  i2 = i12; // row
 
 
                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
                     //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -11155,25 +11155,26 @@ static void ggml_compute_forward_mul_mat_id(
                     // TODO: this is a bit of a hack, we should probably have a better way to handle this
                     // TODO: this is a bit of a hack, we should probably have a better way to handle this
                     const char * src1_col = (const char *) wdata +
                     const char * src1_col = (const char *) wdata +
                         (src1_cont || src1->type != vec_dot_type
                         (src1_cont || src1->type != vec_dot_type
-                        ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
-                        : (i11*nb11 + i12*nb12 + i13*nb13));
+                        ? (i11      + i12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12));
 
 
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
 
 
                     //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                     //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                     //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                     //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                     //}
                     //}
 
 
                     for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                     for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
+                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
                     }
                     }
+
                     memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
                     memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
                 }
                 }
             }
             }
         }
         }
     }
     }
 
 
-    #undef MMID_MATRIX_ROW
+#undef MMID_MATRIX_ROW
 }
 }
 
 
 // ggml_compute_forward_out_prod
 // ggml_compute_forward_out_prod
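
To make the index mapping in the loop above concrete, here is a hypothetical trace (illustration only, non-broadcast case with ne11 = n_expert_used = 2):

    //   ids = [[0, 2], [2, 3]]   (2 tokens, 2 experts used per token)
    //   grouping: expert 2 collects rows {(i1 = 1, i2 = 0), (i1 = 0, i2 = 1)}
    //   for the second of those rows: id = 0, i12 = 1
    //     src1 column read   -> i11 = id % ne11 = 0, i12 = 1  (b[:, 0, 1])
    //     dst column written -> i1  = id = 0,        i2  = 1  (dst[:, 0, 1])
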
@@ -18512,7 +18513,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     const int n_as = src0->ne[2];
                     const int n_as = src0->ne[2];
                     cur += GGML_PAD(cur, sizeof(int64_t));       // align
                     cur += GGML_PAD(cur, sizeof(int64_t));       // align
                     cur += n_as * sizeof(int64_t);               // matrix_row_counts
                     cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                    cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
+                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
                 } break;
                 } break;
             case GGML_OP_OUT_PROD:
             case GGML_OP_OUT_PROD:
                 {
                 {
@@ -20938,12 +20939,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
 
             ok = ok && cur != NULL;
             ok = ok && cur != NULL;
 
 
-            ggml_set_name(cur, ctx->infos[i].name.data);
-
             if (!ok) {
             if (!ok) {
                 break;
                 break;
             }
             }
 
 
+            ggml_set_name(cur, ctx->infos[i].name.data);
+
             // point the data member to the appropriate location in the binary blob using the tensor infos
             // point the data member to the appropriate location in the binary blob using the tensor infos
             if (!params.no_alloc) {
             if (!params.no_alloc) {
               //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
               //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file

+ 2 - 4
ggml.h

@@ -1161,13 +1161,11 @@ extern "C" {
             enum ggml_prec       prec);
             enum ggml_prec       prec);
 
 
     // indirect matrix multiplication
     // indirect matrix multiplication
-    //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
             struct ggml_context * ctx,
             struct ggml_tensor  * as,
             struct ggml_tensor  * as,
-            struct ggml_tensor  * ids,
-            int                   id,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
 
 
     // A: m columns, n rows,
     // A: m columns, n rows,
     // B: p columns, n rows,
     // B: p columns, n rows,
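
A minimal usage sketch of the new signature (not from the patch; sizes are hypothetical and an initialized ggml_context is assumed): all experts are multiplied in a single op, with the I32 ids tensor selecting the experts per token.

    #include "ggml.h"

    static struct ggml_tensor * moe_matmul_sketch(struct ggml_context * ctx) {
        const int64_t k = 256, m = 512;                        // per-expert matrix: k columns, m rows
        const int64_t n_expert = 8, n_used = 2, n_tokens = 4;

        struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, m, n_expert);      // one matrix per expert
        struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, n_used, n_tokens); // one input row per (expert, token)
        struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_used, n_tokens);    // selected experts per token

        return ggml_mul_mat_id(ctx, as, b, ids); // result: [m, n_used, n_tokens]
    }
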

+ 138 - 85
llama.cpp

@@ -4495,6 +4495,13 @@ static bool llm_load_tensors(
 
 
     auto & hparams = model.hparams;
     auto & hparams = model.hparams;
 
 
+#ifdef GGML_USE_SYCL
+    // disable MoE with SYCL until mul_mat_id is updated
+    if (hparams.n_expert > 0) {
+        n_gpu_layers = 0;
+    }
+#endif
+
     model.split_mode   = split_mode;
     model.split_mode   = split_mode;
     model.main_gpu     = main_gpu;
     model.main_gpu     = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
     model.n_gpu_layers = n_gpu_layers;
@@ -6099,6 +6106,100 @@ static struct ggml_tensor * llm_build_ffn(
     return cur;
     return cur;
 }
 }
 
 
+static struct ggml_tensor * llm_build_moe_ffn(
+        struct ggml_context * ctx,
+         struct ggml_tensor * cur,
+         struct ggml_tensor * gate_inp,
+         struct ggml_tensor * up_exps,
+         struct ggml_tensor * gate_exps,
+         struct ggml_tensor * down_exps,
+                    int64_t   n_expert,
+                    int64_t   n_expert_used,
+            llm_ffn_op_type   type_op,
+                       bool   norm_w,
+         const llm_build_cb & cb,
+                        int   il) {
+    int64_t n_embd = cur->ne[0];
+    int64_t n_tokens = cur->ne[1];
+
+    ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens]
+    cb(logits, "ffn_moe_logits", il);
+
+    ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+    cb(probs, "ffn_moe_probs", il);
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = ggml_get_rows(ctx,
+            ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights", il);
+
+    if (norm_w) {
+        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights_sum", il);
+
+        weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+
+        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
+    }
+
+    cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+    ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(gate, "ffn_moe_gate", il);
+
+    switch (type_op) {
+        case LLM_FFN_SILU:
+            {
+                gate = ggml_silu(ctx, gate);
+                cb(gate, "ffn_moe_silu", il);
+            } break;
+        case LLM_FFN_GELU:
+            {
+                gate = ggml_gelu(ctx, gate);
+                cb(gate, "ffn_moe_gelu", il);
+            } break;
+        default:
+            GGML_ASSERT(false);
+    }
+
+    ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
+    cb(par, "ffn_moe_gate_par", il);
+
+    ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx, experts, weights);
+
+    // aggregate experts
+    ggml_tensor * moe_out = nullptr;
+    for (int i = 0; i < n_expert_used; ++i) {
+        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]);
+
+        if (i == 0) {
+            moe_out = cur_expert;
+        } else {
+            moe_out = ggml_add(ctx, moe_out, cur_expert);
+        }
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx, moe_out);
+    }
+
+    return moe_out;
+}
+
 // if max_alibi_bias > 0 then apply ALiBi
 // if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
         struct ggml_context * ctx,
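
llm_build_moe_ffn expresses the routing math as graph ops; a plain C++ reference of the same per-token computation (not from the patch, values are hypothetical) may make the softmax / top-k / optional re-normalization steps easier to follow:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        const std::vector<float> logits = {0.1f, 2.0f, -1.0f, 0.5f}; // hypothetical gate output, n_expert = 4
        const int  n_expert_used = 2;
        const bool norm_w        = true;

        // softmax over the gate logits -> probs
        std::vector<float> probs(logits.size());
        const float mx = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) { probs[i] = std::exp(logits[i] - mx); sum += probs[i]; }
        for (float & p : probs) { p /= sum; }

        // top-k expert selection (what ggml_top_k provides in the graph)
        std::vector<int> idx(probs.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                [&](int a, int b) { return probs[a] > probs[b]; });
        idx.resize(n_expert_used);

        // gather the selected weights and optionally re-normalize them to sum to 1
        std::vector<float> weights;
        for (int i : idx) { weights.push_back(probs[i]); }
        if (norm_w) {
            const float ws = std::accumulate(weights.begin(), weights.end(), 0.0f);
            for (float & w : weights) { w /= ws; }
        }

        for (int i = 0; i < n_expert_used; ++i) {
            printf("expert %d weight %.3f\n", idx[i], weights[i]);
        }
        return 0;
    }
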
@@ -6642,7 +6743,15 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
                 cb(cur, "ffn_norm", il);
 
 
-                cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
+                cur = llm_build_moe_ffn(ctx0, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        cb, il);
+                cb(cur, "ffn_moe_out", il);
             }
             }
 
 
             cur = ggml_add(ctx0, cur, ffn_inp);
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -6674,80 +6783,6 @@ struct llm_build_context {
         return gf;
         return gf;
     }
     }
 
 
-    // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
-    ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, bool norm_w, int il) {
-        ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
-        cb(logits, "ffn_moe_logits", il);
-
-        ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
-        cb(probs, "ffn_moe_probs", il);
-
-        // select experts
-        ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
-        cb(selected_experts->src[0], "ffn_moe_argsort", il);
-
-        ggml_tensor * weights = ggml_get_rows(ctx0,
-                                              ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
-        cb(weights, "ffn_moe_weights", il);
-
-        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
-
-        if (norm_w) {
-            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
-            cb(weights_sum, "ffn_moe_weights_sum", il);
-
-            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
-            cb(weights, "ffn_moe_weights_norm", il);
-        }
-
-        // compute expert outputs
-        ggml_tensor * moe_out = nullptr;
-
-        for (int i = 0; i < n_expert_used; ++i) {
-            ggml_tensor * cur_expert;
-
-            ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
-            cb(cur_up, "ffn_moe_up", il);
-
-            ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
-            cb(gate, "ffn_moe_gate", il);
-
-            switch (type_op) {
-                case LLM_FFN_SILU:
-                {
-                    gate = ggml_silu(ctx0, gate);
-                    cb(gate, "ffn_moe_silu", il);
-                } break;
-                case LLM_FFN_GELU:
-                {
-                    gate = ggml_gelu(ctx0, gate);
-                    cb(gate, "ffn_moe_gelu", il);
-                } break;
-                default:
-                    GGML_ASSERT(false);
-            }
-
-            cur_expert = ggml_mul(ctx0, cur_up, gate);
-            cb(cur_expert, "ffn_moe_gate_par", il);
-
-            cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
-            cb(cur_expert, "ffn_moe_down", il);
-
-            cur_expert = ggml_mul(ctx0, cur_expert,
-                                  ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
-            cb(cur_expert, "ffn_moe_weighted", il);
-
-            if (i == 0) {
-                moe_out = cur_expert;
-            } else {
-                moe_out = ggml_add(ctx0, moe_out, cur_expert);
-                cb(moe_out, "ffn_moe_out", il);
-            }
-        }
-
-        return moe_out;
-    }
-
     struct ggml_cgraph * build_baichuan() {
     struct ggml_cgraph * build_baichuan() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
 
@@ -7195,7 +7230,15 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
             cb(cur, "ffn_norm", il);
 
 
-            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, true, il);
+            cur = llm_build_moe_ffn(ctx0, cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    n_expert, n_expert_used,
+                    LLM_FFN_GELU, true,
+                    cb, il);
+            cb(cur, "ffn_moe_out", il);
 
 
             // Grok
             // Grok
             // if layer_out_norm is present then apply it before adding the input
             // if layer_out_norm is present then apply it before adding the input
@@ -7207,7 +7250,6 @@ struct llm_build_context {
                 cb(cur, "layer_out_norm", il);
                 cb(cur, "layer_out_norm", il);
             }
             }
 
 
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
             cb(cur, "ffn_out", il);
 
 
@@ -7331,7 +7373,15 @@ struct llm_build_context {
                                  LLM_NORM, cb, il);
                                  LLM_NORM, cb, il);
             cb(cur, "attn_out_norm", il);
             cb(cur, "attn_out_norm", il);
 
 
-            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
+            cur = llm_build_moe_ffn(ctx0, cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    cb, il);
+            cb(cur, "ffn_moe_out", il);
 
 
             cur = ggml_add(ctx0, cur, ffn_inp);
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
             cb(cur, "ffn_out", il);
@@ -8502,12 +8552,6 @@ struct llm_build_context {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
                 cb(Vcur, "Vcur", il);
 
 
-                // these nodes are added to the graph together so that they are not reordered
-                // by doing so, the number of splits in the graph is reduced
-                ggml_build_forward_expand(gf, Qcur);
-                ggml_build_forward_expand(gf, Kcur);
-                ggml_build_forward_expand(gf, Vcur);
-
                 Qcur = ggml_rope_custom(
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -8658,7 +8702,16 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
             cb(cur, "ffn_norm", il);
 
 
-            ggml_tensor * moe_out = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, false, il);
+            ggml_tensor * moe_out =
+                    llm_build_moe_ffn(ctx0, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        cb, il);
+            cb(cur, "ffn_moe_out", il);
 
 
             // FFN shared expert
             // FFN shared expert
             {
             {

+ 1 - 13
scripts/compare-commits.sh

@@ -12,19 +12,7 @@ bench_args="${@:3}"
 
 
 rm -f llama-bench.sqlite
 rm -f llama-bench.sqlite
 
 
-backend="cpu"
-
-if [[ "$OSTYPE" == "darwin"* ]]; then
-    backend="metal"
-elif command -v nvcc &> /dev/null; then
-    backend="cuda"
-fi
-
-make_opts=""
-
-if [[ "$backend" == "cuda" ]]; then
-    make_opts="LLAMA_CUDA=1"
-fi
+# to test a backend, call the script with the corresponding environment variable (e.g. LLAMA_CUDA=1 ./scripts/compare-commits.sh ...)
 
 
 git checkout $1
 git checkout $1
 make clean && make -j32 $make_opts llama-bench
 make clean && make -j32 $make_opts llama-bench

+ 66 - 21
tests/test-backend-ops.cpp

@@ -101,7 +101,7 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
                     } else if (t->type == GGML_TYPE_I8) {
                     } else if (t->type == GGML_TYPE_I8) {
                         tv.push_back((float)*(int8_t *) &buf[i]);
                         tv.push_back((float)*(int8_t *) &buf[i]);
                     } else if (quantized) {
                     } else if (quantized) {
-                        tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
+                        tt.to_float(&buf[i], vq.data(), bs);
                         tv.insert(tv.end(), vq.begin(), vq.end());
                         tv.insert(tv.end(), vq.begin(), vq.end());
                     } else {
                     } else {
                         GGML_ASSERT(false);
                         GGML_ASSERT(false);
@@ -948,14 +948,14 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_a;
     const ggml_type type_a;
     const ggml_type type_b;
     const ggml_type type_b;
     const int n_mats;
     const int n_mats;
-    const int id;
+    const int n_used;
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t m;
     const int64_t n;
     const int64_t n;
     const int64_t k;
     const int64_t k;
-    const bool v; // view (non-contiguous ids)
 
 
     std::string vars() override {
     std::string vars() override {
-        return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
+        return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
     }
     }
 
 
     double max_nmse_err() override {
     double max_nmse_err() override {
@@ -972,20 +972,22 @@ struct test_mul_mat_id : public test_case {
     }
     }
 
 
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
-            int n_mats = 2, int id = 0,
-            int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
-        : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
-            m(m), n(n), k(k), v(v) {}
+            int n_mats = 8, int n_used = 2, bool b = false,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32)
+        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
+            m(m), n(n), k(k) {
+            GGML_ASSERT(n_used <= n_mats);
+        }
 
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
-        ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
+        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
         ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
         ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
-        if (v) {
-            ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
+        if (n_used != n_mats) {
+            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
         }
         }
-        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
-        ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b);
+        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
         return out;
         return out;
     }
     }
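
For the broadcast case exercised when b == true above (illustration only, following the shape asserts of ggml_mul_mat_id):

    //   as = [k, m, n_mats], ids = [n_used, n], b = [k, 1, n]
    //   ids->ne[0] % b->ne[1] == 0 holds trivially because b->ne[1] == 1,
    //   so each token's single src1 row is reused for all n_used selected experts;
    //   the output is still [m, n_used, n].
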
 
 
@@ -1611,7 +1613,6 @@ public:
     }
     }
 };
 };
 
 
-
 // Llama
 // Llama
 struct test_llama : public test_llm {
 struct test_llama : public test_llm {
     static constexpr float freq_base = 10000.0f;
     static constexpr float freq_base = 10000.0f;
@@ -1875,6 +1876,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
     };
     };
 
 
+    const ggml_type base_types[] = {
+        GGML_TYPE_F32, GGML_TYPE_F16,
+        GGML_TYPE_Q4_0,
+        GGML_TYPE_Q4_K,
+        GGML_TYPE_IQ2_XXS
+    };
+
+    const ggml_type other_types[] = {
+        GGML_TYPE_Q4_1,
+        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+        GGML_TYPE_Q8_0,
+        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+        GGML_TYPE_Q5_K,
+        GGML_TYPE_Q6_K,
+        GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
+        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
+        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+    };
+
     // unary ops
     // unary ops
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
         test_cases.emplace_back(new test_unary((ggml_unary_op) op));
         test_cases.emplace_back(new test_unary((ggml_unary_op) op));
@@ -1983,7 +2003,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
     }
 
 
-    for (ggml_type type_a : all_types) {
+    for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {1, 1}));
@@ -2003,6 +2023,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
         }
     }
     }
 
 
+    for (ggml_type type_a : other_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
+        }
+    }
+
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));
@@ -2010,13 +2036,32 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));
 
 
-    for (ggml_type type_a : all_types) {
+    for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
-            for (int n_mats : {2, 4, 8}) {
-                for (int id = 0; id < n_mats; id++) {
-                    for (bool v : {false, true}) {
-                        test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16,  1, 256, v));
-                        test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
+            for (int n_mats : {4, 8}) {
+                for (int n_used : {1, 2, 4}) {
+                    for (bool b : {false, true}) {
+                        for (int n : {1, 32}) {
+                            int m = 512;
+                            int k = 256;
+                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    for (ggml_type type_a : other_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            for (int n_mats : {4}) {
+                for (int n_used : {2}) {
+                    for (bool b : {false}) {
+                        for (int n : {1}) {
+                            int m = 512;
+                            int k = 256;
+                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
+                        }
                     }
                     }
                 }
                 }
             }
             }

Some files were not shown because too many files changed in this diff