瀏覽代碼

metal : cap threadgroups size of set_rows (#17146)

Georgi Gerganov 2 月之前
父節點
當前提交
13730c183b
共有 1 個文件被更改,包括 5 次插入0 次删除
  1. 5 0
      ggml/src/ggml-metal/ggml-metal-ops.cpp

+ 5 - 0
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -1036,6 +1036,11 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
 
     nth = std::min(nth, nk0);
 
+    if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+        nrptg = 1;
+    }
+
     ggml_metal_kargs_set_rows args = {
         /*.nk0  =*/ nk0,
         /*.ne01 =*/ ne01,