Просмотр исходного кода

metal : fix fusion across different encoders (#14849)

* metal : fix fusion across different encoders

ggml-ci

* cont : add assertion

ggml-ci
Georgi Gerganov 5 месяцев назад
Родитель
Сommit
065908cb09
1 измененных файлов с 10 добавлено и 3 удалено
  1. 10 3
      ggml/src/ggml-metal/ggml-metal.m

+ 10 - 3
ggml/src/ggml-metal/ggml-metal.m

@@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
 static int ggml_metal_encode_node(
 static int ggml_metal_encode_node(
                         ggml_backend_t   backend,
                         ggml_backend_t   backend,
                                    int   idx,
                                    int   idx,
+                                   int   idx_end,
           id<MTLComputeCommandEncoder>   encoder,
           id<MTLComputeCommandEncoder>   encoder,
             struct ggml_metal_mem_pool * mem_pool) {
             struct ggml_metal_mem_pool * mem_pool) {
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_context        * ctx     = backend->context;
@@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node(
                     size_t offs_fuse;
                     size_t offs_fuse;
                     id<MTLBuffer> id_fuse;
                     id<MTLBuffer> id_fuse;
 
 
-                    for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+                    // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
+                    //       across splits. idx_end indicates the last node in the current split
+                    for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
                         if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
                         if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
                             break;
                             break;
                         }
                         }
@@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node(
                     ops[1] = GGML_OP_MUL;
                     ops[1] = GGML_OP_MUL;
                     ops[2] = GGML_OP_ADD;
                     ops[2] = GGML_OP_ADD;
 
 
-                    for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+                    for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
                         if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
                         if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
                             break;
                             break;
                         }
                         }
@@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
             }
 
 
-            const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+            if (idx + res > node_end) {
+                GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
+                        "https://github.com/ggml-org/llama.cpp/pull/14849");
+            }
 
 
             if (should_capture) {
             if (should_capture) {
                 [encoder popDebugGroup];
                 [encoder popDebugGroup];