@@ -3,6 +3,7 @@
 #import "ggml-impl.h"
 #import "ggml-backend-impl.h"
 #import "ggml-metal-impl.h"
+#import "ggml-metal-common.h"
 
 #import <Foundation/Foundation.h>
@@ -61,8 +62,11 @@ static struct ggml_backend_metal_device_context {
     bool has_bfloat;
     bool use_bfloat;
     bool use_fusion;
+    bool use_concurrency;
     bool use_shared_buffers;
+    bool use_graph_optimize;
 
+    int debug_graph;
     int debug_fusion;
 
     // how many times a given op was fused
@@ -83,7 +87,10 @@ static struct ggml_backend_metal_device_context {
     /*.has_bfloat         =*/ false,
     /*.use_bfloat         =*/ false,
     /*.use_fusion         =*/ true,
+    /*.use_concurrency    =*/ true,
     /*.use_shared_buffers =*/ true,
+    /*.use_graph_optimize =*/ true,
+    /*.debug_graph        =*/ 0,
     /*.debug_fusion       =*/ 0,
     /*.fuse_cnt           =*/ { 0 },
     /*.max_size           =*/ 0,
@@ -124,7 +131,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 #else
     ctx->use_bfloat = false;
 #endif
-    ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+    ctx->use_fusion      = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+    ctx->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
+
+    {
+        const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
+        ctx->debug_graph = val ? atoi(val) : 0;
+    }
 
     {
         const char * val = getenv("GGML_METAL_FUSION_DEBUG");
@@ -137,6 +151,12 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
         ctx->use_shared_buffers = false;
     }
 
+    ctx->use_graph_optimize = true;
+
+    if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
+        ctx->use_graph_optimize = false;
+    }
+
     memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
     ctx->max_size = ctx->mtl_device.maxBufferLength;
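
The two hunks above introduce the new runtime knobs. As a hedged usage sketch (plain C, env var names taken from the hunks; `setenv` is POSIX and must run before the Metal backend is initialized, since the flags are read once in `ggml_backend_metal_device_acq`):

```c
#include <stdlib.h>

// sketch: opting out of the new behavior / enabling debug logs at runtime
int main(void) {
    setenv("GGML_METAL_CONCURRENCY_DISABLE",    "1", 1); // serial dispatch, as before this patch
    setenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE", "1", 1); // keep the graph node order untouched
    setenv("GGML_METAL_GRAPH_DEBUG",            "2", 1); // verbose per-node concurrency logging
    // ... initialize ggml and the Metal backend after this point ...
    return 0;
}
```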
@@ -628,7 +648,7 @@ static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
 @end
 
 //
-// ggml_metal_mem_pool
+// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
 //
 
 struct ggml_metal_mem_pool {
@@ -791,6 +811,9 @@ struct ggml_metal_command_buffer {
 
     // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
     struct ggml_metal_mem_pool * mem_pool;
+
+    // used to enable concurrent execution of ops in the command buffers
+    struct ggml_mem_ranges * mem_ranges;
 };
 
 struct ggml_backend_metal_context {
@@ -1091,7 +1114,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_LOG_INFO("%s: has bfloat         = %s\n", __func__, ctx_dev->has_bfloat         ? "true" : "false");
     GGML_LOG_INFO("%s: use bfloat         = %s\n", __func__, ctx_dev->use_bfloat         ? "true" : "false");
     GGML_LOG_INFO("%s: use fusion         = %s\n", __func__, ctx_dev->use_fusion         ? "true" : "false");
+    GGML_LOG_INFO("%s: use concurrency    = %s\n", __func__, ctx_dev->use_concurrency    ? "true" : "false");
     GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false");
+    GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, ctx_dev->use_graph_optimize ? "true" : "false");
     GGML_LOG_INFO("%s: hasUnifiedMemory   = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
@@ -1105,6 +1130,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
         ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
         ctx->cmd_bufs[i].mem_pool->device = device;
+
+        if (ctx_dev->use_concurrency) {
+            ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
+        }
     }
 
     ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
@@ -1715,6 +1744,10 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         }
 
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+
+        if (ctx->cmd_bufs[i].mem_ranges) {
+            ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
+        }
     }
 
     [ctx->cmd_bufs_ext removeAllObjects];
@@ -2071,12 +2104,51 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         }
     }
 }
 
-static int ggml_metal_encode_node(
-        ggml_backend_t backend,
-        int idx,
-        int idx_end,
-        id<MTLComputeCommandEncoder> encoder,
-        struct ggml_metal_mem_pool * mem_pool) {
+struct ggml_metal_encode_context {
+    ggml_backend_t backend;
+
+    id<MTLComputeCommandEncoder> encoder;
+
+    struct ggml_metal_mem_pool * mem_pool;
+
+    struct ggml_mem_ranges * mem_ranges;
+};
+
+static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+
+    ggml_mem_ranges_reset(ctx->mem_ranges);
+
+    return true;
+}
+
+static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return false;
+    }
+
+    return ggml_mem_ranges_check(ctx->mem_ranges, node);
+}
+
+static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    return ggml_mem_ranges_add(ctx->mem_ranges, node);
+}
+
+static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
+    ggml_backend_t backend = ctx_enc->backend;
+
+    id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
+
+    struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
+
     struct ggml_backend_metal_context * ctx = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -2159,38 +2231,71 @@ static int ggml_metal_encode_node(
     const uint64_t nb2 = dst ? dst->nb[2] : 0;
     const uint64_t nb3 = dst ? dst->nb[3] : 0;
 
+    size_t offs_src[GGML_MAX_SRC];
+
+    id<MTLBuffer> id_src[GGML_MAX_SRC];
+
+    enum ggml_type srct[GGML_MAX_SRC];
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        offs_src[i] = 0;
+        id_src[i]   = node->src[i] ? ggml_metal_get_buffer(node->src[i], &offs_src[i]) : nil;
+        srct[i]     = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT;
+    }
+
+    // TODO: tmp shorthands - remove
+    size_t offs_src0 = offs_src[0];
+    size_t offs_src1 = offs_src[1];
+    size_t offs_src2 = offs_src[2];
+
+    id<MTLBuffer> id_src0 = id_src[0];
+    id<MTLBuffer> id_src1 = id_src[1];
+    id<MTLBuffer> id_src2 = id_src[2];
+
     const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
     const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
     const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
     const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
 
-    size_t offs_src0 = 0;
-    size_t offs_src1 = 0;
-    size_t offs_src2 = 0;
-    size_t offs_dst  = 0;
+    size_t offs_dst = 0;
 
-    id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
-    id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
-    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
-    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
+    id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
 
     int n_fuse = 1;
 
-#if 0
-    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    if (src0) {
-        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
-                ggml_is_contiguous(src0), src0->name);
-    }
-    if (src1) {
-        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                ggml_is_contiguous(src1), src1->name);
-    }
-    if (dst) {
-        GGML_LOG_INFO("%s:  dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
-                dst->name);
+    // check if the current node can run concurrently with other nodes before it
+    // the condition is that:
+    // - the current node cannot write to any previous src or dst ranges
+    // - the current node cannot read from any previous dst ranges
+    //
+    // if the condition is not satisfied, we put a memory barrier and clear all ranges
+    // otherwise, we add the new ranges to the encoding context and process the node concurrently
+    //
+    {
+        const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
+
+        if (!is_concurrent) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
+        }
+
+        if (ctx_dev->debug_graph > 0) {
+            GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(dst->op), is_concurrent ? "(concurrent)" : "");
+        }
+        if (ctx_dev->debug_graph > 1) {
+            if (src0) {
+                GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                        ggml_is_contiguous(src0), src0->name);
+            }
+            if (src1) {
+                GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                        ggml_is_contiguous(src1), src1->name);
+            }
+            if (dst) {
+                GGML_LOG_DEBUG("%s:  dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                        dst->name);
+            }
+        }
     }
-#endif
 
     id<MTLDevice> device = ctx_dev->mtl_device;
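
The comment block above is the heart of the change. The `ggml_mem_ranges_*` functions live in `ggml-metal-common` and are not part of this file, so as a reading aid only, here is a toy CPU model of the stated rule. Everything in it (names, the flat array, the fixed capacity) is invented for illustration and is not the actual implementation:

```c
#include <stdbool.h>
#include <stdint.h>

// one [pb, pe) buffer interval touched by an already-encoded node
typedef struct { uintptr_t pb, pe; bool is_dst; } range_t;

typedef struct { range_t r[1024]; int n; } mem_ranges_t;

static bool overlap(range_t a, range_t b) {
    return a.pb < b.pe && b.pb < a.pe;
}

// the rule from the comment above: a write may not touch any previous range,
// and a read may not touch any previous dst (write) range
static bool mem_ranges_check(const mem_ranges_t * mrs, range_t nr) {
    for (int i = 0; i < mrs->n; i++) {
        if (nr.is_dst && overlap(nr, mrs->r[i])) {
            return false; // write-after-read or write-after-write hazard
        }
        if (!nr.is_dst && mrs->r[i].is_dst && overlap(nr, mrs->r[i])) {
            return false; // read-after-write hazard
        }
    }
    return true;
}

static bool mem_ranges_add(mem_ranges_t * mrs, range_t nr) {
    if (mrs->n == 1024) {
        return false; // full: the caller falls back to a barrier + reset
    }
    mrs->r[mrs->n++] = nr;
    return true;
}

static void mem_ranges_reset(mem_ranges_t * mrs) {
    mrs->n = 0; // called right after a memory barrier is encoded
}
```

A node then maps to one check/add per src (its reads) plus one for dst (its write), which is exactly the shape of the `ggml_metal_encode_concurrency_*` wrappers defined earlier in this diff.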
@@ -2389,6 +2494,14 @@ static int ggml_metal_encode_node(
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+                    for (int i = 1; i < n_fuse; ++i) {
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
+
+                            break;
+                        }
+                    }
                 }
 
                 [encoder setComputePipelineState:pipeline];
@@ -2533,6 +2646,8 @@ static int ggml_metal_encode_node(
                     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
+                    ggml_metal_encode_concurrency_reset(ctx_enc);
                 }
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
@@ -3997,6 +4112,12 @@ static int ggml_metal_encode_node(
                     default: break;
                 }
 
+                // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
+                //       reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
+                //       so we add this extra barrier to prevent the race.
+                //       the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
+                ggml_metal_encode_concurrency_reset(ctx_enc);
+
                 // tokens per expert
                 const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
                 id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
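
The race the TODO warns about is easy to state concretely. Under the assumption spelled out in the comment (the pool recycles buffers), two MUL_MAT_ID nodes encoded into the same concurrent pass could be handed the same temporary buffer; a schematic timeline, not actual code:

```c
// node A: h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); // -> pooled buffer B
//         ... dispatch kernel writing B ...                   // still in flight
// node B: h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); // may also return B
//         ... dispatch kernel writing B ...                   // would race with node A
//
// the mem-ranges tracker cannot see pooled allocations (they are not graph
// tensors), hence the unconditional barrier above until [TAG_MEM_POOL_REMOVE]
```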
@@ -4057,6 +4178,9 @@ static int ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
                 }
 
+                // this barrier is always needed because the next kernel has to wait for the id maps to be computed
+                ggml_metal_encode_concurrency_reset(ctx_enc);
+
                 {
                     id<MTLComputePipelineState> pipeline = nil;
 
@@ -4525,6 +4649,14 @@ static int ggml_metal_encode_node(
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+                    for (int i = 1; i < n_fuse; ++i) {
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
+
+                            break;
+                        }
+                    }
                 }
 
                 id<MTLComputePipelineState> pipeline;
@@ -4668,7 +4800,6 @@ static int ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
-
                 // make sure we have one or more position id(ne10) per token(ne02)
                 GGML_ASSERT(ne10 % ne02 == 0);
                 GGML_ASSERT(ne10 >= ne02);
@@ -5427,6 +5558,10 @@ static int ggml_metal_encode_node(
                 GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
                 GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
+                // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
+                // still, we assume that concurrent FA won't happen before we do the refactor
+                //ggml_metal_encode_concurrency_reset(ctx_enc);
+
                 const int32_t nrows = ne1*ne2*ne3;
 
                 // temp buffer for writing the results from each workgroup
@@ -5447,6 +5582,8 @@ static int ggml_metal_encode_node(
                 [encoder setThreadgroupMemoryLength:smem atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
+                ggml_metal_encode_concurrency_reset(ctx_enc);
+
                 // reduce the results from the workgroups
                 {
                     ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
@@ -5677,7 +5814,7 @@ static int ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
             } break;
-       case GGML_OP_ARGMAX:
+        case GGML_OP_ARGMAX:
             {
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
                 GGML_ASSERT(ggml_is_contiguous_1(src0));
@@ -5709,6 +5846,19 @@ static int ggml_metal_encode_node(
             }
     }
 
+    if (ctx_dev->debug_graph > 0) {
+        if (n_fuse > 1) {
+            GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
+        }
+    }
+
+    // update the mem ranges in the encoding context
+    for (int i = 0; i < n_fuse; ++i) {
+        if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
+        }
+    }
+
     return n_fuse;
 }
|
|
|
@@ -5719,7 +5869,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
|
|
|
|
// number of nodes encoded by the main thread (empirically determined)
|
|
|
- const int n_main = 128;
|
|
|
+ const int n_main = 64;
|
|
|
|
|
|
// number of threads in addition to the main thread
|
|
|
const int n_cb = ctx->n_cb;
|
|
|
@@ -5774,6 +5924,7 @@ static enum ggml_status ggml_metal_graph_compute(
         // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
         // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
        // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
+        // [TAG_MEM_POOL_REMOVE]
         //id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
         id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
         [cmd_buf retain];
@@ -6547,6 +6698,18 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend,
     return ggml_metal_graph_compute(backend, cgraph);
 }
 
+static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    //const int64_t t_start = ggml_time_us();
+
+    if (ctx_dev->use_graph_optimize) {
+        ggml_metal_graph_optimize(cgraph);
+    }
+
+    //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
+}
+
 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
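
`ggml_metal_graph_optimize` itself is declared in `ggml-metal-common.h` and is not shown in this diff. Judging from how it is wired up here, the intent (an assumption, not taken from this file) is to reorder independent nodes so that longer barrier-free stretches survive the per-node concurrency check from earlier:

```c
// schematic only:
//
//   before:  A, B(reads A), C, D(reads C)   // barrier before B and before D
//   after :  A, C, B(reads A), D(reads C)   // A and C overlap, one barrier,
//                                           // then B and D overlap
```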
@@ -6573,12 +6736,25 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
-        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+        struct ggml_metal_mem_pool * mem_pool   = ctx->cmd_bufs[cb_idx].mem_pool;
+        struct ggml_mem_ranges *     mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
 
         ggml_metal_mem_pool_reset(mem_pool);
 
-        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
+        if (mem_ranges) {
+            ggml_mem_ranges_reset(mem_ranges);
+        }
+
+        id<MTLComputeCommandEncoder> encoder;
+
+        struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+        if (ctx_dev->use_concurrency) {
+            encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
+        } else {
+            encoder = [cmd_buf computeCommandEncoder];
+        }
 
         int node_start = 0;
         int node_end   = n_nodes_0;
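
For context on the hunk above: with `MTLDispatchTypeConcurrent`, Metal drops the implicit ordering between dispatches within a compute pass, which is exactly what makes the explicit `memoryBarrierWithScope:` in `ggml_metal_encode_concurrency_reset` necessary. A minimal standalone sketch of the pattern (the pipeline states and sizes are placeholders, not from this diff):

```objc
#import <Metal/Metal.h>

// encode two independent kernels so they may overlap, then order a dependent
// kernel behind a buffer-scope memory barrier
static void encode_concurrent_sketch(id<MTLCommandBuffer> cmd_buf,
                                     id<MTLComputePipelineState> kA,
                                     id<MTLComputePipelineState> kB,
                                     id<MTLComputePipelineState> kC,
                                     MTLSize tg, MTLSize tpg) {
    id<MTLComputeCommandEncoder> enc =
        [cmd_buf computeCommandEncoderWithDispatchType:MTLDispatchTypeConcurrent];

    [enc setComputePipelineState:kA];
    [enc dispatchThreadgroups:tg threadsPerThreadgroup:tpg]; // independent of kB
    [enc setComputePipelineState:kB];
    [enc dispatchThreadgroups:tg threadsPerThreadgroup:tpg]; // may overlap kA

    // kC reads what kA/kB wrote: their buffer writes must become visible first
    [enc memoryBarrierWithScope:MTLBarrierScopeBuffers];

    [enc setComputePipelineState:kC];
    [enc dispatchThreadgroups:tg threadsPerThreadgroup:tpg];

    [enc endEncoding];
}
```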
@@ -6590,12 +6766,19 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
+        struct ggml_metal_encode_context ctx_enc = {
+            /*.backend    =*/ backend,
+            /*.encoder    =*/ encoder,
+            /*.mem_pool   =*/ mem_pool,
+            /*.mem_ranges =*/ mem_ranges,
+        };
+
         for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end);
             if (idx + res > node_end) {
                 GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
                         "https://github.com/ggml-org/llama.cpp/pull/14849");
@@ -6638,7 +6821,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
     /* .event_record   = */ NULL,
     /* .event_wait     = */ NULL,
-    /* .optimize_graph = */ NULL,
+    /* .optimize_graph = */ ggml_backend_metal_graph_optimize,
 };
 
 static ggml_guid_t ggml_backend_metal_guid(void) {