Просмотр исходного кода

metal : cap threadgroups size of set_rows (#17146)

Georgi Gerganov 2 месяцев назад
Родитель
Сommit
13730c183b
1 измененных файлов с 5 добавлено и 0 удалено
  1. 5 0
      ggml/src/ggml-metal/ggml-metal-ops.cpp

+ 5 - 0
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -1036,6 +1036,11 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
 
     nth = std::min(nth, nk0);
 
+    if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+        nrptg = 1;
+    }
+
     ggml_metal_kargs_set_rows args = {
         /*.nk0  =*/ nk0,
         /*.ne01 =*/ ne01,