@@ -1036,6 +1036,11 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, nk0);
+ if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+ nrptg = 1;
+ }
+
ggml_metal_kargs_set_rows args = {
/*.nk0 =*/ nk0,
/*.ne01 =*/ ne01,