|
@@ -42,6 +42,7 @@ struct ggml_metal_context {
|
|
|
id<MTLComputePipelineState> pipeline_##name
|
|
id<MTLComputePipelineState> pipeline_##name
|
|
|
|
|
|
|
|
GGML_METAL_DECL_KERNEL(add);
|
|
GGML_METAL_DECL_KERNEL(add);
|
|
|
|
|
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
|
|
GGML_METAL_DECL_KERNEL(mul);
|
|
GGML_METAL_DECL_KERNEL(mul);
|
|
|
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
|
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
|
|
GGML_METAL_DECL_KERNEL(scale);
|
|
GGML_METAL_DECL_KERNEL(scale);
|
|
@@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
|
|
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(add);
|
|
GGML_METAL_ADD_KERNEL(add);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(add_row);
|
|
|
GGML_METAL_ADD_KERNEL(mul);
|
|
GGML_METAL_ADD_KERNEL(mul);
|
|
|
GGML_METAL_ADD_KERNEL(mul_row);
|
|
GGML_METAL_ADD_KERNEL(mul_row);
|
|
|
GGML_METAL_ADD_KERNEL(scale);
|
|
GGML_METAL_ADD_KERNEL(scale);
|
|
@@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
|
|
|
encoder = [command_buffer computeCommandEncoder];
|
|
encoder = [command_buffer computeCommandEncoder];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- [encoder setComputePipelineState:ctx->pipeline_add];
|
|
|
|
|
|
|
+ if (ggml_nelements(src1) == ne10) {
|
|
|
|
|
+ // src1 is a row
|
|
|
|
|
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
|
|
|
|
|
+ } else {
|
|
|
|
|
+ [encoder setComputePipelineState:ctx->pipeline_add];
|
|
|
|
|
+ }
|
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
|
|
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
|
|
|
|
|
|
|
const int64_t n = ggml_nelements(dst);
|
|
const int64_t n = ggml_nelements(dst);
|
|
|
|
|
|
|
@@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_OP_DUP:
|
|
|
case GGML_OP_CPY:
|
|
case GGML_OP_CPY:
|
|
|
|
|
+ case GGML_OP_CONT:
|
|
|
{
|
|
{
|
|
|
if (encoder == nil) {
|
|
if (encoder == nil) {
|
|
|
encoder = [command_buffer computeCommandEncoder];
|
|
encoder = [command_buffer computeCommandEncoder];
|