|
|
@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
|
|
|
nth *= 2;
|
|
|
}
|
|
|
|
|
|
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
|
nth = MIN(nth, ne00);
|
|
|
|
|
|
ggml_metal_kargs_sum_rows args = {
|
|
|
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
|
|
|
nth *= 2;
|
|
|
}
|
|
|
|
|
|
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
|
nth = MIN(nth, ne00/4);
|
|
|
|
|
|
ggml_metal_kargs_rms_norm args = {
|
|
|
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
|
|
|
nth *= 2;
|
|
|
}
|
|
|
|
|
|
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
|
nth = MIN(nth, ne00/4);
|
|
|
|
|
|
ggml_metal_kargs_l2_norm args = {
|
|
|
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
|
|
|
nth *= 2;
|
|
|
}
|
|
|
|
|
|
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
|
nth = MIN(nth, ne00/4);
|
|
|
|
|
|
ggml_metal_kargs_norm args = {
|
|
|
@@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
|
|
|
default: GGML_ABORT("not implemented");
|
|
|
}
|
|
|
|
|
|
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
|
+
|
|
|
+ // TODO: support
|
|
|
+ //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
|
|
|
+ const int32_t nk00 = ne00;
|
|
|
+
|
|
|
+ int nth = 32; // SIMD width
|
|
|
+
|
|
|
+ while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
|
+ nth *= 2;
|
|
|
+ }
|
|
|
+
|
|
|
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
|
+
|
|
|
+ // when rows are small, we can batch them together in a single threadgroup
|
|
|
+ int nrptg = 1;
|
|
|
+
|
|
|
+ // TODO: relax this constraint in the future
|
|
|
+ if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
|
|
|
+ if (nth > nk00) {
|
|
|
+ nrptg = (nth + nk00 - 1)/nk00;
|
|
|
+ nth = nk00;
|
|
|
+
|
|
|
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
|
+ nrptg--;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ nth = MIN(nth, nk00);
|
|
|
+
|
|
|
ggml_metal_kargs_cpy args = {
|
|
|
- /*.ne00 =*/ ne00,
|
|
|
+ /*.ne00 =*/ nk00,
|
|
|
/*.ne01 =*/ ne01,
|
|
|
/*.ne02 =*/ ne02,
|
|
|
/*.ne03 =*/ ne03,
|
|
|
@@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
|
|
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
|
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
|
|
-
|
|
|
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
-
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
|
} break;
|
|
|
case GGML_OP_SET:
|
|
|
{
|