|
|
@@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
// TODO: make this more general
|
|
|
GGML_ASSERT(n_as <= 8);
|
|
|
|
|
|
- // max size of the src1ids array in the kernel stack
|
|
|
- GGML_ASSERT(ne11 <= 512);
|
|
|
+ // max size of the src1ids array in the kernel shared buffer
|
|
|
+ GGML_ASSERT(ne11 <= 4096);
|
|
|
|
|
|
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
|
|
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
|
@@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
|
|
}
|
|
|
|
|
|
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
|
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
|
} else {
|