瀏覽代碼

sync : ggml (new ops, tests, backend, etc.) (#4359)

* sync : ggml (part 1)

* sync : ggml (part 2, CUDA)

* sync : ggml (part 3, Metal)

* ggml : build fixes

ggml-ci

* cuda : restore lost changes

* cuda : restore lost changes (StableLM rope)

* cmake : enable separable compilation for CUDA

ggml-ci

* ggml-cuda : remove device side dequantize

* Revert "cmake : enable separable compilation for CUDA"

This reverts commit 09e35d04b1c4ca67f9685690160b35bc885a89ac.

* cuda : remove assert for rope

* tests : add test-backend-ops

* ggml : fix bug in ggml_concat

* ggml : restore `ggml_get_n_tasks()` logic in `ggml_graph_plan()`

* ci : try to fix macOS

* ggml-backend : remove backend self-registration

* ci : disable Metal for macOS cmake build

ggml-ci

* metal : fix "supports family" call

* metal : fix assert

* metal : print resource path

ggml-ci

---------

Co-authored-by: slaren <slarengh@gmail.com>
Georgi Gerganov 2 年之前
父節點
當前提交
fe680e3d10
共有 20 個文件被更改,包括 4067 次插入851 次删除
  1. 11 4
      .github/workflows/build.yml
  2. 1 0
      .gitignore
  3. 7 7
      CMakeLists.txt
  4. 5 1
      Makefile
  5. 42 7
      ggml-alloc.c
  6. 7 0
      ggml-alloc.h
  7. 46 21
      ggml-backend-impl.h
  8. 563 156
      ggml-backend.c
  9. 62 17
      ggml-backend.h
  10. 646 254
      ggml-cuda.cu
  11. 9 1
      ggml-cuda.h
  12. 1 1
      ggml-impl.h
  13. 6 0
      ggml-metal.h
  14. 479 155
      ggml-metal.m
  15. 431 121
      ggml-metal.metal
  16. 325 89
      ggml.c
  17. 49 4
      ggml.h
  18. 3 2
      scripts/sync-ggml.sh
  19. 17 11
      tests/CMakeLists.txt
  20. 1357 0
      tests/test-backend-ops.cpp

+ 11 - 4
.github/workflows/build.yml

@@ -143,6 +143,9 @@ jobs:
           cd build
           ctest --verbose
 
+  # TODO: build with LLAMA_NO_METAL because test-backend-ops fail on "Apple Paravirtual device" and I don't know
+  #       how to debug it.
+  #       ref: https://github.com/ggerganov/llama.cpp/actions/runs/7131777249/job/19420981052#step:5:1124
   macOS-latest-make:
     runs-on: macos-latest
 
@@ -160,14 +163,18 @@ jobs:
       - name: Build
         id: make_build
         run: |
-          make -j $(sysctl -n hw.logicalcpu)
+          LLAMA_NO_METAL=1 make -j $(sysctl -n hw.logicalcpu)
 
       - name: Test
         id: make_test
         run: |
-          make tests -j $(sysctl -n hw.logicalcpu)
-          make test -j $(sysctl -n hw.logicalcpu)
+          LLAMA_NO_METAL=1 make tests -j $(sysctl -n hw.logicalcpu)
+          LLAMA_NO_METAL=1 make test  -j $(sysctl -n hw.logicalcpu)
 
+  # TODO: build with LLAMA_METAL=OFF because test-backend-ops fail on "Apple Paravirtual device" and I don't know
+  #       how to debug it.
+  #       ref: https://github.com/ggerganov/llama.cpp/actions/runs/7132125951/job/19422043567?pr=4359#step:5:6584
+  #       would be great if we fix these
   macOS-latest-cmake:
     runs-on: macos-latest
 
@@ -188,7 +195,7 @@ jobs:
           sysctl -a
           mkdir build
           cd build
-          cmake ..
+          cmake -DLLAMA_METAL=OFF ..
           cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
 
       - name: Test

+ 1 - 0
.gitignore

@@ -101,3 +101,4 @@ poetry.toml
 /tests/test-tokenizer-1-llama
 /tests/test-tokenizer-1-bpe
 /tests/test-rope
+/tests/test-backend-ops

+ 7 - 7
CMakeLists.txt

@@ -97,9 +97,9 @@ option(LLAMA_METAL_NDEBUG                    "llama: disable Metal debugging"
 option(LLAMA_MPI                             "llama: use MPI"                                   OFF)
 option(LLAMA_QKK_64                          "llama: use super-block size of 64 for k-quants"   OFF)
 
-option(LLAMA_BUILD_TESTS                "llama: build tests"    ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_EXAMPLES             "llama: build examples" ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_SERVER               "llama: build server example"                           ON)
+option(LLAMA_BUILD_TESTS                     "llama: build tests"    ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_EXAMPLES                  "llama: build examples" ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_SERVER                    "llama: build server example"                      ON)
 
 # Required for relocatable CMake package
 include(${CMAKE_CURRENT_SOURCE_DIR}/scripts/build-info.cmake)
@@ -662,11 +662,11 @@ add_library(ggml OBJECT
             ggml-backend.h
             ggml-quants.c
             ggml-quants.h
-            ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
+            ${GGML_SOURCES_CUDA}   ${GGML_HEADERS_CUDA}
             ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
-            ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
-            ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
-            ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
+            ${GGML_SOURCES_METAL}  ${GGML_HEADERS_METAL}
+            ${GGML_SOURCES_MPI}    ${GGML_HEADERS_MPI}
+            ${GGML_SOURCES_EXTRA}  ${GGML_HEADERS_EXTRA}
             )
 
 target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})

+ 5 - 1
Makefile

@@ -8,7 +8,8 @@ BUILD_TARGETS = \
 TEST_TARGETS = \
 	tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
 	tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama          \
-	tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope
+	tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope      \
+	tests/test-backend-ops
 
 # Code coverage output files
 COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
@@ -746,3 +747,6 @@ tests/test-rope: tests/test-rope.cpp ggml.o $(OBJS)
 
 tests/test-c.o: tests/test-c.c llama.h
 	$(CC) $(CFLAGS) -c $(filter-out %.h,$^) -o $@
+
+tests/test-backend-ops: tests/test-backend-ops.cpp ggml.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

+ 42 - 7
ggml-alloc.c

@@ -168,10 +168,6 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
     size = aligned_offset(NULL, size, alloc->alignment);
     AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
 
-    if (!alloc->measure) {
-        ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
-    }
-
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
 #endif
@@ -237,7 +233,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
 }
 
 ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
-    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
+    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
 
     ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
@@ -449,7 +445,6 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
 static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
     ggml_tallocr_t alloc = node_tallocr(galloc, view);
 
-    //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
     GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
     if (update_backend) {
         view->backend = view->view_src->backend;
@@ -459,7 +454,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
 
     // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
     // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
-    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
+    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
 
     if (!alloc->measure) {
         ggml_backend_buffer_init_tensor(alloc->buffer, view);
@@ -765,3 +760,43 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
 size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
     return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
 }
+
+// utils
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+
+    size_t nbytes = 0;
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL && t->view_src == NULL) {
+            nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+    }
+
+    if (nbytes == 0) {
+        fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
+    ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
+
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                ggml_tallocr_alloc(tallocr, t);
+            } else {
+                ggml_backend_view_init(buffer, t);
+            }
+        }
+    }
+
+    ggml_tallocr_free(tallocr);
+
+    return buffer;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}

+ 7 - 0
ggml-alloc.h

@@ -8,6 +8,7 @@ extern "C" {
 
 struct ggml_backend;
 struct ggml_backend_buffer;
+struct ggml_backend_buffer_type;
 
 //
 // Legacy API
@@ -80,6 +81,12 @@ GGML_API void   ggml_gallocr_alloc_graph_n(
                     struct ggml_hash_set hash_set,
                     ggml_tallocr_t * hash_node_talloc);
 
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
+
 #ifdef  __cplusplus
 }
 #endif

+ 46 - 21
ggml-backend-impl.h

@@ -12,31 +12,50 @@ extern "C" {
     // Backend buffer
     //
 
+    // buffer type
+    typedef void * ggml_backend_buffer_type_context_t;
+
+    struct ggml_backend_buffer_type_i {
+        ggml_backend_buffer_t (*alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
+        size_t                (*get_alignment)   (ggml_backend_buffer_type_t buft); // tensor alignment
+        size_t                (*get_alloc_size)  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
+        bool                  (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
+    };
+
+    struct ggml_backend_buffer_type {
+        struct ggml_backend_buffer_type_i  iface;
+        ggml_backend_buffer_type_context_t context;
+    };
+
+    // buffer
     typedef void * ggml_backend_buffer_context_t;
 
     struct ggml_backend_buffer_i {
-        void   (*free_buffer)   (ggml_backend_buffer_t buffer);
-        void * (*get_base)      (ggml_backend_buffer_t buffer); // get base pointer
-        size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
-        void   (*init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
-        void   (*free_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+        void     (*free_buffer)(ggml_backend_buffer_t buffer);
+        //void     (*reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        void *   (*get_base)   (ggml_backend_buffer_t buffer);
+        void     (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void     (*set_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void     (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        // (optional) copy tensor between different buffer-type, allow for single-copy tranfers
+        void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to)  (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
     };
 
     struct ggml_backend_buffer {
-        struct ggml_backend_buffer_i iface;
-
-        ggml_backend_t                backend;
+        struct ggml_backend_buffer_i  iface;
+        ggml_backend_buffer_type_t    buft;
         ggml_backend_buffer_context_t context;
-
         size_t size;
     };
 
-    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
-            struct ggml_backend                  * backend,
+    ggml_backend_buffer_t ggml_backend_buffer_init(
+                   ggml_backend_buffer_type_t      buft,
             struct ggml_backend_buffer_i           iface,
                    ggml_backend_buffer_context_t   context,
                    size_t                          size);
 
+
     //
     // Backend
     //
@@ -49,20 +68,17 @@ extern "C" {
         void (*free)(ggml_backend_t backend);
 
         // buffer allocation
-        ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
+        ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
 
-        // get buffer alignment
-        size_t (*get_alignment)(ggml_backend_t backend);
-
-        // tensor data access
-        // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
+        // (optional) asynchroneous tensor data access
         void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        void (*synchronize)     (ggml_backend_t backend);
 
-        // (optional) copy tensor between different backends, allow for single-copy tranfers
-        void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
-        void (*cpy_tensor_to)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        // (optional) asynchroneous tensor copy
+        void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to_async)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        void (*synchronize)     (ggml_backend_t backend);
 
         // compute graph with a plan
         ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
@@ -82,6 +98,15 @@ extern "C" {
         ggml_backend_context_t context;
     };
 
+
+    //
+    // Backend registry
+    //
+
+    typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
+
+    void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
 #ifdef  __cplusplus
 }
 #endif

+ 563 - 156
ggml-backend.c

@@ -9,14 +9,36 @@
 #include <stdlib.h>
 #include <string.h>
 
-#define UNUSED GGML_UNUSED
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+
+// backend buffer type
+
+ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    return buft->iface.alloc_buffer(buft, size);
+}
+
+size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_alignment(buft);
+}
+
+size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+    // get_alloc_size is optional, defaults to ggml_nbytes
+    if (buft->iface.get_alloc_size) {
+        return buft->iface.get_alloc_size(buft, tensor);
+    }
+    return ggml_nbytes(tensor);
+}
+
+bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return buft->iface.supports_backend(buft, backend);
+}
+
 // backend buffer
 
 ggml_backend_buffer_t ggml_backend_buffer_init(
-        struct ggml_backend                  * backend,
+               ggml_backend_buffer_type_t      buft,
         struct ggml_backend_buffer_i           iface,
                ggml_backend_buffer_context_t   context,
                size_t                          size) {
@@ -26,7 +48,7 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
 
     (*buffer) = (struct ggml_backend_buffer) {
         /* .interface = */ iface,
-        /* .backend   = */ backend,
+        /* .buft      = */ buft,
         /* .context   = */ context,
         /* .size      = */ size,
     };
@@ -45,10 +67,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
     free(buffer);
 }
 
-size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
-    return ggml_backend_get_alignment(buffer->backend);
-}
-
 size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
     return buffer->size;
 }
@@ -61,14 +79,6 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
     return base;
 }
 
-size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
-    // get_alloc_size is optional, defaults to ggml_nbytes
-    if (buffer->iface.get_alloc_size) {
-        return buffer->iface.get_alloc_size(buffer, tensor);
-    }
-    return ggml_nbytes(tensor);
-}
-
 void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     // init_tensor is optional
     if (buffer->iface.init_tensor) {
@@ -76,19 +86,20 @@ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_t
     }
 }
 
-void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
-    // free_tensor is optional
-    if (buffer->iface.free_tensor) {
-        buffer->iface.free_tensor(buffer, tensor);
-    }
+size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_alignment(ggml_backend_buffer_type(buffer));
 }
 
-// backend
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor);
+}
 
-ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
-    return tensor->buffer ? tensor->buffer->backend : NULL;
+ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
+    return buffer->buft;
 }
 
+// backend
+
 const char * ggml_backend_name(ggml_backend_t backend) {
     if (backend == NULL) {
         return "NULL";
@@ -104,43 +115,53 @@ void ggml_backend_free(ggml_backend_t backend) {
     backend->iface.free(backend);
 }
 
+ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+    return backend->iface.get_default_buffer_type(backend);
+}
+
 ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
-    return backend->iface.alloc_buffer(backend, size);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
 }
 
 size_t ggml_backend_get_alignment(ggml_backend_t backend) {
-    return backend->iface.get_alignment(backend);
+    return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
 }
 
-void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
+void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
 }
 
-void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
+void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
 }
 
 void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_backend_t backend = ggml_get_backend(tensor);
-
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(backend != NULL && "tensor backend not set");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
-    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
-    backend->iface.synchronize(backend);
+    tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
 }
 
 void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_backend_t backend = ggml_get_backend(tensor);
-
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(backend != NULL && "tensor backend not set");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 
-    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
-    backend->iface.synchronize(backend);
+    tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
 }
 
 void ggml_backend_synchronize(ggml_backend_t backend) {
+    if (backend->iface.synchronize == NULL) {
+        return;
+    }
+
     backend->iface.synchronize(backend);
 }
 
@@ -154,10 +175,16 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
 
 void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     backend->iface.graph_plan_compute(backend, plan);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
 }
 
 void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     backend->iface.graph_compute(backend, cgraph);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
 }
 
 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -194,14 +221,15 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
 
     // TODO: allow backends to support copy to/from same backend
 
-    if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) {
-        ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst);
-    } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) {
-        ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst);
+    if (dst->buffer->iface.cpy_tensor_from != NULL) {
+        dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
+    } else if (src->buffer->iface.cpy_tensor_to != NULL) {
+        src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
     } else {
         // shouldn't be hit when copying from/to CPU
         #ifndef NDEBUG
-        fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend));
+        fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
+                        "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
         #endif
         size_t nbytes = ggml_nbytes(src);
         void * data = malloc(nbytes);
@@ -211,101 +239,259 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
     }
 }
 
-// backend CPU
+// backend registry
 
-struct ggml_backend_cpu_context {
-    int n_threads;
-    void * work_data;
-    size_t work_size;
+#define GGML_MAX_BACKENDS_REG 16
+
+struct ggml_backend_reg {
+    char name[128];
+    ggml_backend_init_fn init_fn;
+    ggml_backend_buffer_type_t default_buffer_type;
+    void * user_data;
 };
 
-static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
-    return "CPU";
+static struct ggml_backend_reg ggml_backend_registry[GGML_MAX_BACKENDS_REG];
+static size_t ggml_backend_registry_count = 0;
+
+static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
+
+static void ggml_backend_registry_init(void) {
+    static bool initialized = false;
+
+    if (initialized) {
+        return;
+    }
+
+    initialized = true;
 
-    UNUSED(backend);
+    ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL);
+
+    // add forward decls here to avoid including the backend headers
+#ifdef GGML_USE_CUBLAS
+    extern void ggml_backend_cuda_reg_devices(void);
+    ggml_backend_cuda_reg_devices();
+#endif
+
+#ifdef GGML_USE_METAL
+    extern ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
+    extern ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+    ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
+#endif
 }
 
-static void ggml_backend_cpu_free(ggml_backend_t backend) {
-    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-    free(cpu_ctx->work_data);
-    free(cpu_ctx);
-    free(backend);
+void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
+    GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
+
+    int id = ggml_backend_registry_count;
+
+    ggml_backend_registry[id] = (struct ggml_backend_reg) {
+        /* .name                = */ {0},
+        /* .fn                  = */ init_fn,
+        /* .default_buffer_type = */ default_buffer_type,
+        /* .user_data           = */ user_data,
+    };
+
+    snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
+
+#ifndef NDEBUG
+    fprintf(stderr, "%s: registered backend %s\n", __func__, name);
+#endif
+
+    ggml_backend_registry_count++;
+}
+
+size_t ggml_backend_reg_get_count(void) {
+    ggml_backend_registry_init();
+
+    return ggml_backend_registry_count;
+}
+
+size_t ggml_backend_reg_find_by_name(const char * name) {
+    ggml_backend_registry_init();
+
+    for (size_t i = 0; i < ggml_backend_registry_count; i++) {
+        // TODO: case insensitive in a portable way
+        if (strcmp(ggml_backend_registry[i].name, name) == 0) {
+            return i;
+        }
+    }
+    return SIZE_MAX;
+}
+
+// init from backend:params string
+ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
+    ggml_backend_registry_init();
+
+    const char * params = strchr(backend_str, ':');
+    char backend_name[128];
+    if (params == NULL) {
+        strcpy(backend_name, backend_str);
+        params = "";
+    } else {
+        strncpy(backend_name, backend_str, params - backend_str);
+        backend_name[params - backend_str] = '\0';
+        params++;
+    }
+
+    size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
+    if (backend_i == SIZE_MAX) {
+        fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
+        return NULL;
+    }
+
+    return ggml_backend_reg_init_backend(backend_i, params);
+}
+
+const char * ggml_backend_reg_get_name(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].name;
+}
+
+ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
+}
+
+ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].default_buffer_type;
+}
+
+ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
 }
 
+// backend CPU
+
 static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
     return (void *)buffer->context;
 }
 
 static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     free(buffer->context);
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
 }
 
 static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
-    /* .free_buffer    = */ ggml_backend_cpu_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_cpu_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL, // no initialization required
-    /* .free_tensor    = */ NULL, // no cleanup required
+    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
 };
 
 // for buffers from ptr, free is not called
 static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
-    /* .free_buffer    = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
-    /* .get_base       = */ ggml_backend_cpu_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL,
-    /* .free_tensor    = */ NULL,
+    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
 };
 
 static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
 
-static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) {
+static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
     void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
 
     GGML_ASSERT(data != NULL && "failed to allocate buffer");
 
-    return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
 }
 
-static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return TENSOR_ALIGNMENT;
-    UNUSED(backend);
-}
 
-static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_UNUSED(buft);
+}
 
-    memcpy((char *)tensor->data + offset, data, size);
+static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_cpu(backend);
 
-    UNUSED(backend);
+    GGML_UNUSED(buft);
 }
 
-static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy(data, (const char *)tensor->data + offset, size);
+ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cpu = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
 
-    UNUSED(backend);
+    return &ggml_backend_buffer_type_cpu;
 }
 
-static void ggml_backend_cpu_synchronize(ggml_backend_t backend) {
-    UNUSED(backend);
-}
+struct ggml_backend_cpu_context {
+    int n_threads;
+    void * work_data;
+    size_t work_size;
+};
 
-static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
+    return "CPU";
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
-static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    free(cpu_ctx->work_data);
+    free(cpu_ctx);
+    free(backend);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_cpu_buffer_type();
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 struct ggml_backend_plan_cpu {
@@ -334,7 +520,7 @@ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backen
     free(cpu_plan->cplan.work_data);
     free(cpu_plan);
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
@@ -342,7 +528,7 @@ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_bac
 
     ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@@ -363,25 +549,25 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
 
 static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     return true;
-    UNUSED(backend);
-    UNUSED(op);
+
+    GGML_UNUSED(backend);
+    GGML_UNUSED(op);
 }
 
 static struct ggml_backend_i cpu_backend_i = {
-    /* .get_name            = */ ggml_backend_cpu_name,
-    /* .free                = */ ggml_backend_cpu_free,
-    /* .alloc_buffer        = */ ggml_backend_cpu_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_cpu_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_cpu_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_cpu_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_cpu_synchronize,
-    /* .cpy_tensor_from     = */ ggml_backend_cpu_cpy_tensor_from,
-    /* .cpy_tensor_to       = */ ggml_backend_cpu_cpy_tensor_to,
-    /* .graph_plan_create   = */ ggml_backend_cpu_graph_plan_create,
-    /* .graph_plan_free     = */ ggml_backend_cpu_graph_plan_free,
-    /* .graph_plan_compute  = */ ggml_backend_cpu_graph_plan_compute,
-    /* .graph_compute       = */ ggml_backend_cpu_graph_compute,
-    /* .supports_op         = */ ggml_backend_cpu_supports_op,
+    /* .get_name                = */ ggml_backend_cpu_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .supports_op             = */ ggml_backend_cpu_supports_op,
 };
 
 ggml_backend_t ggml_backend_cpu_init(void) {
@@ -411,10 +597,18 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
     ctx->n_threads = n_threads;
 }
 
-ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
-    return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
+ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
+}
+
+static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
 }
 
+
 // scheduler
 
 #define GGML_MAX_BACKENDS 4
@@ -427,7 +621,7 @@ struct ggml_backend_sched_split {
     int i_end;
     struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
     int n_inputs;
-    struct ggml_cgraph * graph;
+    struct ggml_cgraph graph;
 };
 
 struct ggml_backend_sched {
@@ -453,7 +647,7 @@ struct ggml_backend_sched {
     #else
     __attribute__((aligned(GGML_MEM_ALIGN)))
     #endif
-    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
+    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
 };
 
 #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -482,23 +676,57 @@ static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr)
     return INT_MAX;
 }
 
+static ggml_backend_t get_buffer_backend(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+            return sched->backends[i];
+        }
+    }
+    GGML_ASSERT(false && "tensor buffer type not supported by any backend");
+}
+
+static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
+    if (allocr == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->tallocs[i] == allocr) {
+            return sched->backends[i];
+        }
+    }
+    GGML_UNREACHABLE();
+}
+
+#if 0
+static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+#define GET_CAUSE(node) causes[hash_id(node)]
+#else
+#define SET_CAUSE(node, ...)
+#define GET_CAUSE(node) ""
+#endif
+
 // returns the backend that should be used for the node based on the current locations
-char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
 static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
     // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
     // ie. kv cache updates
     // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
     // dst
-    ggml_backend_t cur_backend = ggml_get_backend(node);
+    ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
     if (cur_backend != NULL) {
-        sprintf(causes[hash_id(node)], "1.dst");
+        SET_CAUSE(node, "1.dst");
         return cur_backend;
     }
 
     // view_src
-    if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
-        sprintf(causes[hash_id(node)], "1.vsrc");
-        return ggml_get_backend(node->view_src);
+    if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
+        SET_CAUSE(node, "1.vsrc");
+        return get_buffer_backend(sched, node->view_src->buffer);
     }
 
     // src
@@ -510,7 +738,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
         if (src == NULL) {
             break;
         }
-        ggml_backend_t src_backend = ggml_get_backend(src);
+        ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
         if (src_backend != NULL) {
             int src_prio = sched_backend_prio(sched, src_backend);
             size_t src_size = ggml_nbytes(src);
@@ -518,7 +746,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
                 cur_prio = src_prio;
                 cur_size = src_size;
                 cur_backend = src_backend;
-                sprintf(causes[hash_id(node)], "1.src%d", i);
+                SET_CAUSE(node, "1.src%d", i);
             }
         }
     }
@@ -539,10 +767,12 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
     int cur_split = 0;
     for (int i = 0; i < graph->n_nodes; i++) {
         if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
-            ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
-            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
+            ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
+            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
+                sched->splits[cur_split].n_inputs);
             for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
-                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+                    fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
             }
             fprintf(stderr, "\n");
             cur_split++;
@@ -552,16 +782,18 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
             continue;
         }
         ggml_tallocr_t node_allocr = node_allocr(node);
-        ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
-        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
+        ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
+        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name,
+            fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
                 break;
             }
             ggml_tallocr_t src_allocr = node_allocr(src);
-            ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
-            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
+            ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
+            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
+                fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
         }
         fprintf(stderr, "\n");
     }
@@ -587,9 +819,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
     sched->n_splits = 0;
 
     struct ggml_init_params params = {
-        /*.mem_size =   */ sizeof(sched->context_buffer),
-        /*.mem_buffer = */ sched->context_buffer,
-        /*.no_alloc =   */ true
+        /* .mem_size =   */ sizeof(sched->context_buffer),
+        /* .mem_buffer = */ sched->context_buffer,
+        /* .no_alloc =   */ true
     };
 
     if (sched->ctx != NULL) {
@@ -605,9 +837,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             // do not overwrite user assignments
             continue;
         }
-        ggml_backend_t leaf_backend = ggml_get_backend(leaf);
+        ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
         if (leaf_backend == NULL && leaf->view_src != NULL) {
-            leaf_backend = ggml_get_backend(leaf->view_src);
+            leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
         }
         if (leaf_backend != NULL) {
             node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
@@ -649,7 +881,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                         cur_prio = src_prio;
                         cur_size = src_size;
                         node_allocr = src_allocr;
-                        sprintf(causes[hash_id(node)], "2.src%d", j);
+                        SET_CAUSE(node, "2.src%d", j);
                     }
                 }
             }
@@ -733,7 +965,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                     struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
                     sched->node_copies[id][cur_backend_id] = tensor_copy;
                     node_allocr(tensor_copy) = cur_allocr;
-                    ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
+                    ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
                     ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
                 }
                 node->src[j] = sched->node_copies[id][cur_backend_id];
@@ -761,8 +993,8 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             ggml_tallocr_t src_allocr = node_allocr(src);
             if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
                 fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
-                    node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
-                    j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
+                    node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
+                    j, src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
             }
         }
     }
@@ -773,7 +1005,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
     struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &sched->splits[i];
-        split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
+        split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
 
         // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
         for (int j = 0; j < split->n_inputs; j++) {
@@ -806,31 +1038,29 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
 
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &splits[i];
-        ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
+        ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
         int split_backend_id = sched_backend_prio(sched, split_backend);
 
         // copy the input tensors to the split backend
         uint64_t copy_start_us = ggml_time_us();
         for (int j = 0; j < split->n_inputs; j++) {
-            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
-            if (split->inputs[j]->buffer == NULL) {
-                if (split->inputs[j]->view_src == NULL) {
-                    fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
+            if (input->buffer == NULL) {
+                if (input->view_src == NULL) {
+                    fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
                     exit(1);
                 }
-                struct ggml_tensor * view = split->inputs[j];
-                view->backend = view->view_src->backend;
-                view->buffer  = view->view_src->buffer;
-                view->data    = (char *)view->view_src->data + view->view_offs;
-                ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
+                // FIXME: may need to use the sched buffer instead
+                ggml_backend_view_init(input->view_src->buffer, input);
             }
             if (input_cpy->buffer == NULL) {
                 fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
                 exit(1);
             }
-            GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
-            GGML_ASSERT(input_cpy->buffer->backend == split_backend);
-            ggml_backend_tensor_copy(split->inputs[j], input_cpy);
+            //GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
+            //GGML_ASSERT(input_cpy->buffer->backend == split_backend);
+            ggml_backend_tensor_copy(input, input_cpy);
         }
         // ggml_backend_synchronize(split_backend);
         int64_t copy_end_us = ggml_time_us();
@@ -843,7 +1073,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
 #endif
 
         uint64_t compute_start_us = ggml_time_us();
-        ggml_backend_graph_compute(split_backend, split->graph);
+        ggml_backend_graph_compute(split_backend, &split->graph);
         // ggml_backend_synchronize(split_backend);
         uint64_t compute_end_us = ggml_time_us();
         compute_us[split_backend_id] += compute_end_us - compute_start_us;
@@ -872,8 +1102,6 @@ ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_bac
     struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
     memset(sched, 0, sizeof(struct ggml_backend_sched));
 
-    fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024);
-
     sched->n_backends = n_backends;
     for (int i = 0; i < n_backends; i++) {
         sched->backends[i] = backends[i];
@@ -948,3 +1176,182 @@ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
     node_allocr(node) = sched->tallocs[backend_index];
 }
+
+// utils
+void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src != NULL);
+    GGML_ASSERT(tensor->view_src->buffer != NULL);
+    GGML_ASSERT(tensor->view_src->data != NULL);
+
+    tensor->buffer = buffer;
+    tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+    tensor->backend = tensor->view_src->backend;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src == NULL);
+    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+
+    tensor->buffer = buffer;
+    tensor->data = addr;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+static struct ggml_tensor * graph_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+    struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
+
+    GGML_ASSERT(src != NULL);
+    GGML_ASSERT(src->data && "graph must be allocated");
+
+    size_t id = ggml_hash_insert(hash_set, src);
+    if (id == GGML_HASHTABLE_ALREADY_EXISTS) {
+        return node_copies[ggml_hash_find(hash_set, src)];
+    }
+
+    struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+    if (src->view_src != NULL) {
+        dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+        dst->view_offs = src->view_offs;
+    }
+    dst->op = src->op;
+    memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+    ggml_set_name(dst, src->name);
+
+    // copy src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+    }
+
+    node_copies[id] = dst;
+    return dst;
+}
+
+static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+    size_t id = ggml_hash_find(hash_set, src);
+    if (node_init[id]) {
+        return;
+    }
+    node_init[id] = true;
+
+    struct ggml_tensor * dst = node_copies[id];
+    if (dst->view_src != NULL) {
+        ggml_backend_view_init(dst->view_src->buffer, dst);
+    }
+    else {
+        ggml_backend_tensor_copy(src, dst);
+    }
+
+    // init src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        graph_init_tensor(hash_set, node_copies, node_init, s);
+    }
+}
+
+struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+    struct ggml_hash_set hash_set = {
+        /* .size = */ graph->visited_hash_table.size,
+        /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
+    };
+    struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
+    bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ true
+    };
+
+    struct ggml_context * ctx_allocated = ggml_init(params);
+    struct ggml_context * ctx_unallocated = ggml_init(params);
+
+    // dup nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+    }
+
+    // allocate nodes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+
+    //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+    // copy data and init views
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_init_tensor(hash_set, node_copies, node_init, node);
+    }
+
+    // build graph copy
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)];
+        graph_copy->nodes[i] = node_copy;
+    }
+    graph_copy->n_nodes = graph->n_nodes;
+
+    free(hash_set.keys);
+    free(node_copies);
+    free(node_init);
+
+    return (struct ggml_backend_graph_copy) {
+        /* .buffer           = */ buffer,
+        /* .ctx_allocated    = */ ctx_allocated,
+        /* .ctx_unallocated  = */ ctx_unallocated,
+        /* .graph            = */ graph_copy,
+    };
+}
+
+void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
+    ggml_backend_buffer_free(copy.buffer);
+    ggml_free(copy.ctx_allocated);
+    ggml_free(copy.ctx_unallocated);
+}
+
+void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
+    struct ggml_cgraph * g1 = graph;
+    struct ggml_cgraph * g2 = copy.graph;
+
+    assert(g1->n_nodes == g2->n_nodes);
+
+    for (int i = 0; i < g1->n_nodes; i++) {
+        //printf("eval %d/%d\n", i, g1->n_nodes);
+        struct ggml_tensor * t1 = g1->nodes[i];
+        struct ggml_tensor * t2 = g2->nodes[i];
+
+        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+        ggml_backend_graph_compute(backend1, &g1v);
+        ggml_backend_graph_compute(backend2, &g2v);
+
+        if (ggml_is_view_op(t1->op)) {
+            continue;
+        }
+
+        // compare results, calculate rms etc
+        if (!callback(i, t1, t2, user_data)) {
+            break;
+        }
+    }
+
+    ggml_backend_graph_copy_free(copy);
+}

+ 62 - 17
ggml-backend.h

@@ -7,41 +7,44 @@
 extern "C" {
 #endif
 
+    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+
     //
     // Backend buffer
     //
 
-    struct ggml_backend_buffer;
-    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    // buffer type
+    GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+    GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
 
-    // backend buffer functions
+    // buffer
     GGML_API void   ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
-    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
     GGML_API void * ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
     GGML_API size_t ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
-    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
     GGML_API void   ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API void   ggml_backend_buffer_free_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
 
     //
     // Backend
     //
 
-    struct ggml_backend;
-    typedef struct ggml_backend * ggml_backend_t;
-    typedef void * ggml_backend_graph_plan_t;
-
-    GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
 
     GGML_API const char * ggml_backend_name(ggml_backend_t backend);
     GGML_API void         ggml_backend_free(ggml_backend_t backend);
 
-    GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
-
-    GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
 
-    GGML_API void ggml_backend_tensor_set_async(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 
     GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
@@ -57,6 +60,7 @@ extern "C" {
 
     // tensor copy between different backends
     GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
 
     //
     // CPU backend
@@ -68,8 +72,23 @@ extern "C" {
     GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
 
     // Create a backend buffer from an existing pointer
-    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
+    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
 
+    //
+    // Backend registry
+    //
+
+    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+
+    GGML_API size_t                     ggml_backend_reg_get_count(void);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
+    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
+    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
+    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
 
     //
     // Backend scheduler
@@ -131,6 +150,32 @@ extern "C" {
             ggml_backend_sched_t sched,
             struct ggml_cgraph * graph);
 
+
+    //
+    // Utils
+    //
+
+    struct ggml_backend_graph_copy {
+        ggml_backend_buffer_t buffer;
+        struct ggml_context * ctx_allocated;
+        struct ggml_context * ctx_unallocated;
+        struct ggml_cgraph * graph;
+    };
+
+    // Copy a graph to a different backend
+    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+    // Compare the output of two backends
+    GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+    // Tensor initialization
+    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+
+
 #ifdef  __cplusplus
 }
 #endif

文件差異過大導致無法顯示
+ 646 - 254
ggml-cuda.cu


+ 9 - 1
ggml-cuda.h

@@ -49,7 +49,15 @@ GGML_API int    ggml_cuda_get_device_count(void);
 GGML_API void   ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
+GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
+GGML_API int  ggml_backend_cuda_get_device(ggml_backend_t backend);
+
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
 #ifdef  __cplusplus
 }

+ 1 - 1
ggml-impl.h

@@ -232,7 +232,7 @@ bool   ggml_hash_contains      (const struct ggml_hash_set hash_set, struct ggml
 // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
 size_t ggml_hash_find          (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
 
-// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
 size_t ggml_hash_insert        (      struct ggml_hash_set hash_set, struct ggml_tensor * key);
 
 // return index, asserts if table is full

+ 6 - 0
ggml-metal.h

@@ -99,6 +99,12 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
 GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
+GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+
+// helper to check if the device supports a specific family
+// ideally, the user code should be doing these checks
+// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
 
 #ifdef __cplusplus
 }

+ 479 - 155
ggml-metal.m

@@ -62,6 +62,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    GGML_METAL_DECL_KERNEL(div);
+    GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
@@ -112,10 +114,24 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+    GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -126,6 +142,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
+    GGML_METAL_DECL_KERNEL(sum_rows);
 
 #undef GGML_METAL_DECL_KERNEL
 };
@@ -169,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     }
 }
 
-
-
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
-    id <MTLDevice> device;
+    id<MTLDevice> device;
     NSString * s;
 
 #if TARGET_OS_OSX
@@ -220,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
             NSString * sourcePath;
             NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+            GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
             if (ggmlMetalPathResources) {
                 sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
             } else {
@@ -250,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         }
     }
 
+#if TARGET_OS_OSX
+    // print MTL GPU family:
+    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+            break;
+        }
+    }
+
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+    if (ctx->device.maxTransferRate != 0) {
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+    } else {
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
+    }
+#endif
+
     // load kernels
     {
         NSError * error = nil;
@@ -271,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
+        GGML_METAL_ADD_KERNEL(div);
+        GGML_METAL_ADD_KERNEL(div_row);
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
@@ -322,11 +365,25 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
         }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+        GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -337,33 +394,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
+        GGML_METAL_ADD_KERNEL(sum_rows);
 
 #undef GGML_METAL_ADD_KERNEL
     }
 
-#if TARGET_OS_OSX
-    // print MTL GPU family:
-    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
-
-    // determine max supported GPU family
-    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
-    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-        if ([ctx->device supportsFamily:i]) {
-            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
-            break;
-        }
-    }
-
-    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",        __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    if (ctx->device.maxTransferRate != 0) {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
-    } else {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
-    }
-#endif
-
     return ctx;
 }
 
@@ -377,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(add_row);
     GGML_METAL_DEL_KERNEL(mul);
     GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(div);
+    GGML_METAL_DEL_KERNEL(div_row);
     GGML_METAL_DEL_KERNEL(scale);
     GGML_METAL_DEL_KERNEL(scale_4);
     GGML_METAL_DEL_KERNEL(silu);
@@ -428,11 +465,25 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
         GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
     }
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+    GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -443,6 +494,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
+    GGML_METAL_DEL_KERNEL(sum_rows);
 
 #undef GGML_METAL_DEL_KERNEL
 
@@ -486,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
     return ctx->concur_list;
 }
 
+// temporarily defined here for compatibility between ggml-backend and the old API
+struct ggml_backend_metal_buffer_context {
+    void * data;
+
+    id<MTLBuffer> metal;
+};
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -495,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
-    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
-        ctx = t->buffer->backend->context;
+    // compatibility with ggml-backend
+    if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+        struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+        GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+        *offs = (size_t) ioffs;
+
+        return buf_ctx->metal;
     }
 
     // find the view that contains the tensor fully
@@ -721,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
     }
 }
 
+static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU:
+                    return true;
+                default:
+                    return false;
+            }
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_CONCAT:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_NORM:
+        case GGML_OP_ALIBI:
+        case GGML_OP_ROPE:
+        case GGML_OP_IM2COL:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_DUP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+            return true;
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_GET_ROWS:
+            {
+                return op->ne[0] % 4 == 0;
+            }
+        default:
+            return false;
+    }
+}
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
@@ -791,6 +904,8 @@ void ggml_metal_graph_compute(
                         } break;
                 }
 
+                GGML_ASSERT(ggml_metal_supports_op(dst));
+
                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
                 const int64_t  ne02 = src0 ? src0->ne[2] : 0;
@@ -883,6 +998,8 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ADD:
+                    case GGML_OP_MUL:
+                    case GGML_OP_DIV:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
                             GGML_ASSERT(ggml_is_contiguous(src1));
@@ -896,11 +1013,21 @@ void ggml_metal_graph_compute(
                                 GGML_ASSERT(ne11 == 1);
 
                                 nb = ne00 / 4;
-                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                                switch (dst->op) {
+                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                                    default: GGML_ASSERT(false);
+                                }
 
                                 bcast_row = true;
                             } else {
-                                [encoder setComputePipelineState:ctx->pipeline_add];
+                                switch (dst->op) {
+                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                                    default: GGML_ASSERT(false);
+                                }
                             }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -941,31 +1068,6 @@ void ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
                         } break;
-                    case GGML_OP_MUL:
-                        {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
-                            GGML_ASSERT(ggml_is_contiguous(src1));
-
-                            // utilize float4
-                            GGML_ASSERT(ne00 % 4 == 0);
-                            const int64_t nb = ne00/4;
-
-                            if (ggml_nelements(src1) == ne10) {
-                                // src1 is a row
-                                GGML_ASSERT(ne11 == 1);
-                                [encoder setComputePipelineState:ctx->pipeline_mul_row];
-                            } else {
-                                [encoder setComputePipelineState:ctx->pipeline_mul];
-                            }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
-
-                            const int64_t n = ggml_nelements(dst)/4;
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
                     case GGML_OP_SCALE:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1038,6 +1140,40 @@ void ggml_metal_graph_compute(
                             const int64_t n = ggml_nelements(dst);
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_SUM_ROWS:
+                        {
+                            GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+                            [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             int nth = 32; // SIMD width
@@ -1092,13 +1228,17 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL_MAT:
                         {
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne03 == ne13);
 
-                            const uint gqa = ne12/ne02;
+                            // TODO: assert that dim2 and dim3 are contiguous
+                            GGML_ASSERT(ne12 % ne02 == 0);
+                            GGML_ASSERT(ne13 % ne03 == 0);
+
+                            const uint r2 = ne12/ne02;
+                            const uint r3 = ne13/ne03;
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
-                            int ne11_mm_min = src0t == GGML_TYPE_F16 ? 1 : 16;
+                            int ne11_mm_min = 1;
 
 #if 0
                             // the numbers below are measured on M2 Ultra for 7B and 13B models
@@ -1159,9 +1299,10 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
                                 [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
@@ -1197,90 +1338,60 @@ void ggml_metal_graph_compute(
                                         } break;
                                     case GGML_TYPE_Q4_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                                         } break;
                                     case GGML_TYPE_Q4_1:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                                         } break;
                                     case GGML_TYPE_Q5_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
                                         } break;
                                     case GGML_TYPE_Q5_1:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
                                         } break;
                                     case GGML_TYPE_Q8_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                                         } break;
                                     case GGML_TYPE_Q2_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 4; //1;
                                             nth1 = 8; //32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1309,34 +1420,127 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
+                                [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
+                                [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
                                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     int64_t ny = (ne11 + nrows - 1)/nrows;
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
                         } break;
+                    case GGML_OP_MUL_MAT_ID:
+                        {
+                            //GGML_ASSERT(ne00 == ne10);
+                            //GGML_ASSERT(ne03 == ne13);
+
+                            GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+                            const int n_as = ne00;
+
+                            // TODO: make this more general
+                            GGML_ASSERT(n_as <= 8);
+
+                            struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+                            const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+                            const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+                            const int64_t  ne22 = src2 ? src2->ne[2] : 0;
+                            const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+                            const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+                            const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                            const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                            const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+                            const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                            GGML_ASSERT(!ggml_is_transposed(src2));
+                            GGML_ASSERT(!ggml_is_transposed(src1));
+
+                            GGML_ASSERT(ne20 % 32 == 0);
+                            // !!!!!!!!! TODO: this assert is probably required but not sure!
+                            //GGML_ASSERT(ne20 >= 64);
+                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                            const uint r2 = ne12/ne22;
+                            const uint r3 = ne13/ne23;
+
+                            // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                            // to the matrix-vector kernel
+                            int ne11_mm_min = 0;
+
+                            const int idx = ((int32_t *) dst->op_params)[0];
+
+                            // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                            // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                ne11 > ne11_mm_min) {
+                                switch (src2->type) {
+                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
+                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+                                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+                                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+                                    default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+                                }
+                                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:3];
+                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:4];
+                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:5];
+                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:6];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
+                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:15];
+                                // TODO: how to make this an array? read Metal docs
+                                for (int j = 0; j < n_as; ++j) {
+                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                                    size_t offs_src_cur = 0;
+                                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+                                }
+
+                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            }
+                        } break;
                     case GGML_OP_GET_ROWS:
                         {
                             switch (src0->type) {
@@ -1560,6 +1764,27 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                         } break;
+                    case GGML_OP_ARGSORT:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                            GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+                            const int nrows = ggml_nrows(src0);
+
+                            enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+                            switch (order) {
+                                case GGML_SORT_ASC:  [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc];  break;
+                                case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
@@ -1655,81 +1880,150 @@ void ggml_metal_graph_compute(
 
 // backend interface
 
-static const char * ggml_backend_metal_name(ggml_backend_t backend) {
-    return "Metal";
+static id<MTLDevice> g_backend_device = nil;
+static int g_backend_device_ref_count = 0;
 
-    UNUSED(backend);
+static id<MTLDevice> ggml_backend_metal_get_device(void) {
+    if (g_backend_device == nil) {
+        g_backend_device = MTLCreateSystemDefaultDevice();
+    }
+
+    g_backend_device_ref_count++;
+
+    return g_backend_device;
 }
 
-static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
-    ggml_metal_free(ctx);
-    free(backend);
+static void ggml_backend_metal_free_device(void) {
+    assert(g_backend_device_ref_count > 0);
+
+    g_backend_device_ref_count--;
+
+    if (g_backend_device_ref_count == 0) {
+        [g_backend_device release];
+        g_backend_device = nil;
+    }
 }
 
 static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return (void *)buffer->context;
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    return ctx->data;
 }
 
 static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    free(buffer->context);
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    [ctx->metal release];
+    ggml_backend_metal_free_device();
+
+    free(ctx->data);
+    free(ctx);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
     UNUSED(buffer);
 }
 
 static struct ggml_backend_buffer_i metal_backend_buffer_i = {
-    /* .free_buffer    = */ ggml_backend_metal_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_metal_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL, // no initialization required
-    /* .free_tensor    = */ NULL, // no cleanup required
+    /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_metal_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_metal_buffer_cpy_tensor_to,
 };
 
-static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
-    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
 
-    void * data = ggml_metal_host_malloc(size);
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
 
-    // TODO: set proper name of the buffers
-    ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+    ctx->data  = ggml_metal_host_malloc(size);
+    ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+                    length:size_aligned
+                    options:MTLResourceStorageModeShared
+                    deallocator:nil];
 
-    return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
 }
 
-static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 32;
-    UNUSED(backend);
+    UNUSED(buft);
 }
 
-static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
-    memcpy((char *)tensor->data + offset, data, size);
-
-    UNUSED(backend);
+    GGML_UNUSED(buft);
 }
 
-static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy(data, (const char *)tensor->data + offset, size);
+ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
 
-    UNUSED(backend);
+    return &ggml_backend_buffer_type_metal;
 }
 
-static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+    return "Metal";
+
     UNUSED(backend);
 }
 
-static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+static void ggml_backend_metal_free(ggml_backend_t backend) {
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+    ggml_metal_free(ctx);
+    free(backend);
+}
 
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
     UNUSED(backend);
 }
 
-static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_metal_buffer_type();
 
     UNUSED(backend);
 }
@@ -1741,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
 }
 
 static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return true;
+    return ggml_metal_supports_op(op);
+
     UNUSED(backend);
-    UNUSED(op);
 }
 
 static struct ggml_backend_i metal_backend_i = {
-    /* .get_name            = */ ggml_backend_metal_name,
-    /* .free                = */ ggml_backend_metal_free,
-    /* .alloc_buffer        = */ ggml_backend_metal_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_metal_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_metal_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_metal_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_metal_synchronize,
-    /* .cpy_tensor_from     = */ ggml_backend_metal_cpy_tensor_from,
-    /* .cpy_tensor_to       = */ ggml_backend_metal_cpy_tensor_to,
-    /* .graph_plan_create   = */ NULL, // the metal implementation does not require creating graph plans atm
-    /* .graph_plan_free     = */ NULL,
-    /* .graph_plan_compute  = */ NULL,
-    /* .graph_compute       = */ ggml_backend_metal_graph_compute,
-    /* .supports_op         = */ ggml_backend_metal_supports_op,
+    /* .get_name                = */ ggml_backend_metal_name,
+    /* .free                    = */ ggml_backend_metal_free,
+    /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ ggml_backend_metal_synchronize,
+    /* .graph_plan_create       = */ NULL, // the metal implementation does not require creating graph plans atm
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_metal_graph_compute,
+    /* .supports_op             = */ ggml_backend_metal_supports_op,
 };
 
+// TODO: make a common log callback for all backends in ggml-backend
+static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    fprintf(stderr, "%s", msg);
+
+    UNUSED(level);
+    UNUSED(user_data);
+}
+
 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+    ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
+
+    struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
 
-    ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+    if (ctx == NULL) {
+        return NULL;
+    }
 
     ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
 
@@ -1783,7 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
 }
 
 void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
     struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
 
     ggml_metal_set_n_cb(ctx, n_cb);
 }
+
+bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
+ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+    return ggml_backend_metal_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
+}

文件差異過大導致無法顯示
+ 431 - 121
ggml-metal.metal


+ 325 - 89
ggml.c

@@ -233,24 +233,6 @@ inline static void * ggml_aligned_malloc(size_t size) {
 #define UNUSED GGML_UNUSED
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
-//
-// tensor access macros
-//
-
-#define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
-#define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
 #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -1613,6 +1595,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GROUP_NORM",
 
     "MUL_MAT",
+    "MUL_MAT_ID",
     "OUT_PROD",
 
     "SCALE",
@@ -1640,6 +1623,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
+    "ARGSORT",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -1666,7 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1695,6 +1679,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "group_norm(x)",
 
     "X*Y",
+    "X[i]*Y",
     "X*Y",
 
     "x*v",
@@ -1722,6 +1707,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
+    "argsort(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -1748,10 +1734,28 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
+
+static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
+    "ABS",
+    "SGN",
+    "NEG",
+    "STEP",
+    "TANH",
+    "ELU",
+    "RELU",
+    "GELU",
+    "GELU_QUICK",
+    "SILU",
+    "LEAKY",
+};
+
+static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
+
+
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
@@ -1771,6 +1775,7 @@ static void ggml_setup_op_has_task_pass(void) {
 
         p[GGML_OP_ACC                    ] = true;
         p[GGML_OP_MUL_MAT                ] = true;
+        p[GGML_OP_MUL_MAT_ID             ] = true;
         p[GGML_OP_OUT_PROD               ] = true;
         p[GGML_OP_SET                    ] = true;
         p[GGML_OP_GET_ROWS_BACK          ] = true;
@@ -2023,6 +2028,20 @@ const char * ggml_op_symbol(enum ggml_op op) {
     return GGML_OP_SYMBOL[op];
 }
 
+const char * ggml_unary_op_name(enum ggml_unary_op op) {
+    return GGML_UNARY_OP_NAME[op];
+}
+
+const char * ggml_op_desc(const struct ggml_tensor * t) {
+    if (t->op == GGML_OP_UNARY) {
+        enum ggml_unary_op uop = ggml_get_unary_op(t);
+        return ggml_unary_op_name(uop);
+    }
+    else {
+        return ggml_op_name(t->op);
+    }
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return ggml_type_size(tensor->type);
 }
@@ -3154,9 +3173,7 @@ static struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    // TODO: support less-strict constraint
-    //       GGML_ASSERT(ggml_can_repeat(b, a));
-    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3371,9 +3388,7 @@ static struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    // TODO: support less-strict constraint
-    //       GGML_ASSERT(ggml_can_repeat(b, a));
-    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3418,7 +3433,7 @@ static struct ggml_tensor * ggml_div_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -4056,6 +4071,49 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_mul_mat_id
+
+struct ggml_tensor * ggml_mul_mat_id(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * as[],
+        struct ggml_tensor  * ids,
+        int                   id,
+        struct ggml_tensor  * b) {
+
+    int64_t n_as = ids->ne[0];
+
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
+    GGML_ASSERT(id >= 0 && id < n_as);
+
+    bool is_node = false;
+
+    if (as[0]->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
+
+    ggml_set_op_params_i32(result, 0, id);
+
+    result->op   = GGML_OP_MUL_MAT_ID;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = ids;
+    result->src[1] = b;
+
+    for (int64_t i = 0; i < n_as; i++) {
+        struct ggml_tensor * a = as[i];
+        GGML_ASSERT(ggml_are_same_shape(as[0], a));
+        GGML_ASSERT(ggml_can_mul_mat(a, b));
+        GGML_ASSERT(!ggml_is_transposed(a));
+        result->src[i + 2] = a;
+    }
+
+    return result;
+}
+
 // ggml_out_prod
 
 struct ggml_tensor * ggml_out_prod(
@@ -4209,7 +4267,7 @@ struct ggml_tensor * ggml_set_2d_inplace(
         struct ggml_tensor *  b,
         size_t                nb1,
         size_t                offset) {
-    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
+    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
 }
 
 // ggml_cpy
@@ -5468,6 +5526,43 @@ struct ggml_tensor * ggml_upscale(
     return ggml_upscale_impl(ctx, a, scale_factor);
 }
 
+// ggml_argsort
+
+struct ggml_tensor * ggml_argsort(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_sort_order  order) {
+    bool is_node = false;
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, a->ne);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) order);
+
+    result->op   = GGML_OP_ARGSORT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_top_k
+
+struct ggml_tensor * ggml_top_k(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   k) {
+    GGML_ASSERT(a->ne[0] >= k);
+
+    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_DESC);
+
+    result = ggml_view_4d(ctx, result,
+                k, result->ne[1], result->ne[2], result->ne[3],
+                   result->nb[1], result->nb[2], result->nb[3],
+                0);
+
+    return result;
+}
+
 // ggml_flash_attn
 
 struct ggml_tensor * ggml_flash_attn(
@@ -6827,7 +6922,7 @@ static void ggml_compute_forward_add_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -6860,16 +6955,19 @@ static void ggml_compute_forward_add_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
+                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
+                ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
+            }
         }
     } else {
         // src1 is not contiguous
@@ -6886,8 +6984,9 @@ static void ggml_compute_forward_add_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }
@@ -7607,7 +7706,7 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -7630,7 +7729,6 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
         for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -7642,20 +7740,21 @@ static void ggml_compute_forward_mul_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0 ; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            UNUSED(ggml_vec_mul_f32);
+                UNUSED(ggml_vec_mul_f32);
 
-            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr,  1, ne00);
+                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
+                ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-            // }
+            }
         }
     } else {
         // src1 is not contiguous
@@ -7673,8 +7772,9 @@ static void ggml_compute_forward_mul_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (int64_t i0 = 0; i0 < ne00; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
@@ -7708,14 +7808,16 @@ static void ggml_compute_forward_div_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int nr  = ggml_nrows(src0);
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -7723,41 +7825,50 @@ static void ggml_compute_forward_div_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            UNUSED(ggml_vec_div_f32);
+                UNUSED(ggml_vec_div_f32);
 
-            vDSP_vdiv(
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_div_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+                ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-            // }
+            }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
             }
@@ -8203,7 +8314,7 @@ static void ggml_compute_forward_repeat_f16(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -8348,6 +8459,7 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
+    const int nth = params->nth;
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -8357,7 +8469,7 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb10 == sizeof(float));
 
     for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = ith; i2 < ne2; i2++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
             if (i2 < ne02) { // src0
                 for (int i1 = 0; i1 < ne1; i1++) {
                     for (int i0 = 0; i0 < ne0; i0++) {
@@ -9517,6 +9629,8 @@ static void ggml_compute_forward_mul_mat(
             char * wdata = params->wdata;
             const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
+            assert(params->wsize >= ne11*ne12*ne13*row_size);
+
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
                     for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -9618,6 +9732,26 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
+// ggml_compute_forward_mul_mat_id
+
+static void ggml_compute_forward_mul_mat_id(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    const int id = ggml_get_op_params_i32(dst, 0);
+
+    const int a_id = ((int32_t *)ids->data)[id];
+
+    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+
+    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+
+    ggml_compute_forward_mul_mat(params, src0, src1, dst);
+}
+
 // ggml_compute_forward_out_prod
 
 static void ggml_compute_forward_out_prod_f32(
@@ -12021,6 +12155,67 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+// ggml_compute_forward_argsort
+
+static void ggml_compute_forward_argsort_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+
+        // C doesn't have a functional sort, so we do a bubble sort instead
+        for (int64_t j = 0; j < ne0; j++) {
+            for (int64_t k = j + 1; k < ne0; k++) {
+                if ((order == GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+                    (order == GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+                    int32_t tmp = dst_data[j];
+                    dst_data[j] = dst_data[k];
+                    dst_data[k] = tmp;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_argsort(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argsort_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_attn
 
 static void ggml_compute_forward_flash_attn_f32(
@@ -13844,6 +14039,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                ggml_compute_forward_mul_mat_id(params, tensor);
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
@@ -13948,6 +14147,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_upscale(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                ggml_compute_forward_argsort(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -14598,6 +14801,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 zero_table);
                 }
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -14936,6 +15143,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -15296,12 +15507,8 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
     return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
 }
 
-struct ggml_cgraph * ggml_graph_view(struct ggml_context * ctx, struct ggml_cgraph * cgraph0, int i0, int i1) {
-    const size_t obj_size = sizeof(struct ggml_cgraph);
-    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
-    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
-
-    *cgraph = (struct ggml_cgraph) {
+struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
+    struct ggml_cgraph cgraph = {
         /*.size         =*/ 0,
         /*.n_nodes      =*/ i1 - i0,
         /*.n_leafs      =*/ 0,
@@ -15536,7 +15743,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_SUB:
-        case GGML_OP_DIV:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_LOG:
@@ -15569,10 +15775,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     {
                         n_tasks = n_threads;
                     } break;
+                default:
+                    GGML_ASSERT(false);
             }
             break;
         case GGML_OP_SILU_BACK:
         case GGML_OP_MUL:
+        case GGML_OP_DIV:
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
@@ -15610,6 +15819,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 }
 #endif
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                // FIXME: blas
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;
@@ -15669,6 +15883,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 n_tasks = n_threads;
@@ -15731,6 +15949,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1;
             } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
         default:
             {
                 fprintf(stderr, "%s: op not implemented: ", __func__);
@@ -15927,6 +16149,23 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
                     }
                 } break;
+            case GGML_OP_MUL_MAT_ID:
+                {
+                    const struct ggml_tensor * a = node->src[2];
+                    const struct ggml_tensor * b = node->src[1];
+                    const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                    if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
+                        if (a->type != GGML_TYPE_F32) {
+                            // here we need memory just for single 2D matrix from src0
+                            cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
+                        }
+                    } else
+#endif
+                    if (b->type != vec_dot_type) {
+                        cur = ggml_type_size(vec_dot_type)*ggml_nelements(b)/ggml_blck_size(vec_dot_type);
+                    }
+                } break;
             case GGML_OP_OUT_PROD:
                 {
                     if (ggml_is_quantized(node->src[0]->type)) {
@@ -15962,9 +16201,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         GGML_ASSERT(false);
                     }
                 } break;
-            case GGML_OP_IM2COL:
-                {
-                } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -17803,8 +18039,8 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t *
             memcpy(&qh, &y[i].qh, sizeof(qh));
 
             for (int j = 0; j < QK5_0; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
 
                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -17833,8 +18069,8 @@ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t *
             memcpy(&qh, &y[i].qh, sizeof(qh));
 
             for (int j = 0; j < QK5_1; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
 
                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -18024,6 +18260,7 @@ struct gguf_kv {
 
 struct gguf_header {
     char magic[4];
+
     uint32_t version;
     uint64_t n_tensors; // GGUFv2
     uint64_t n_kv;      // GGUFv2
@@ -18113,7 +18350,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
         for (uint32_t i = 0; i < sizeof(magic); i++) {
             if (magic[i] != GGUF_MAGIC[i]) {
-                fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
+                fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
                 fclose(file);
                 return NULL;
             }
@@ -18128,7 +18365,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     {
         strncpy(ctx->header.magic, magic, 4);
 
-
         ctx->kv    = NULL;
         ctx->infos = NULL;
         ctx->data  = NULL;

+ 49 - 4
ggml.h

@@ -283,6 +283,20 @@
     const type prefix##3 = (pointer)->array[3]; \
     GGML_UNUSED(prefix##3);
 
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
@@ -381,6 +395,7 @@ extern "C" {
         GGML_OP_GROUP_NORM,
 
         GGML_OP_MUL_MAT,
+        GGML_OP_MUL_MAT_ID,
         GGML_OP_OUT_PROD,
 
         GGML_OP_SCALE,
@@ -407,8 +422,8 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
-
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_ARGSORT,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -448,7 +463,9 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_LEAKY
+        GGML_UNARY_OP_LEAKY,
+
+        GGML_UNARY_OP_COUNT,
     };
 
     enum ggml_object_type {
@@ -631,6 +648,9 @@ extern "C" {
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
+    GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
     GGML_API bool    ggml_is_quantized(enum ggml_type type);
@@ -1027,6 +1047,15 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // indirect matrix multiplication
+    //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
+    GGML_API struct ggml_tensor * ggml_mul_mat_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * as[],
+            struct ggml_tensor  * ids,
+            int                   id,
+            struct ggml_tensor  * b);
+
     // A: m columns, n rows,
     // B: p columns, n rows,
     // result is m columns, p rows
@@ -1520,6 +1549,23 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // sort rows
+    enum ggml_sort_order {
+        GGML_SORT_ASC,
+        GGML_SORT_DESC,
+    };
+
+    GGML_API struct ggml_tensor * ggml_argsort(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_sort_order  order);
+
+    // top k elements per row
+    GGML_API struct ggml_tensor * ggml_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1581,7 +1627,6 @@ extern "C" {
             int                   kh);
 
     // used in sam
-
     GGML_API struct ggml_tensor * ggml_add_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1756,7 +1801,7 @@ extern "C" {
     GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom  (struct ggml_context * ctx, size_t size, bool grads);
     GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-    GGML_API struct ggml_cgraph * ggml_graph_view        (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
+    GGML_API struct ggml_cgraph   ggml_graph_view        (struct ggml_cgraph * cgraph, int i0, int i1);
     GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
     GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);

+ 3 - 2
scripts/sync-ggml.sh

@@ -20,5 +20,6 @@ cp -rpv ../ggml/include/ggml/ggml.h         ./ggml.h
 cp -rpv ../ggml/include/ggml/ggml-alloc.h   ./ggml-alloc.h
 cp -rpv ../ggml/include/ggml/ggml-backend.h ./ggml-backend.h
 
-cp -rpv ../ggml/tests/test-opt.cpp    ./tests/test-opt.cpp
-cp -rpv ../ggml/tests/test-grad0.cpp  ./tests/test-grad0.cpp
+cp -rpv ../ggml/tests/test-opt.cpp         ./tests/test-opt.cpp
+cp -rpv ../ggml/tests/test-grad0.cpp       ./tests/test-grad0.cpp
+cp -rpv ../ggml/tests/test-backend-ops.cpp ./tests/test-backend-ops.cpp

+ 17 - 11
tests/CMakeLists.txt

@@ -22,26 +22,32 @@ endfunction()
 llama_build_and_test_executable(test-quantize-fns.cpp)
 llama_build_and_test_executable(test-quantize-perf.cpp)
 llama_build_and_test_executable(test-sampling.cpp)
+
 llama_build_executable(test-tokenizer-0-llama.cpp)
 llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+
 llama_build_executable(test-tokenizer-0-falcon.cpp)
 llama_test_executable (test-tokenizer-0-falcon test-tokenizer-0-falcon.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
+
 llama_build_executable(test-tokenizer-1-llama.cpp)
-llama_test_executable (test-tokenizer-1-llama test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
-llama_test_executable(test-tokenizer-1-baichuan test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
+llama_test_executable (test-tokenizer-1-llama    test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+llama_test_executable (test-tokenizer-1-baichuan test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
+
 llama_build_executable(test-tokenizer-1-bpe.cpp)
-llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
-llama_test_executable(test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
-llama_test_executable(test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
-llama_test_executable(test-tokenizer-1-stablelm-3b-4e1t test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
-llama_test_executable(test-tokenizer-1-gpt-neox test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
-llama_test_executable(test-tokenizer-1-refact test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
-llama_test_executable(test-tokenizer-1-starcoder test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
-# llama_test_executable(test-tokenizer-1-bloom test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
+llama_test_executable (test-tokenizer-1-falcon           test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
+llama_test_executable (test-tokenizer-1-aquila           test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
+llama_test_executable (test-tokenizer-1-mpt              test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
+llama_test_executable (test-tokenizer-1-stablelm-3b-4e1t test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
+llama_test_executable (test-tokenizer-1-gpt-neox         test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
+llama_test_executable (test-tokenizer-1-refact           test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
+llama_test_executable (test-tokenizer-1-starcoder        test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
+# llama_test_executable (test-tokenizer-1-bloom test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
+
 llama_build_and_test_executable(test-grammar-parser.cpp)
 llama_build_and_test_executable(test-llama-grammar.cpp)
-llama_build_and_test_executable(test-grad0.cpp) # SLOW
+llama_build_and_test_executable(test-grad0.cpp)
 # llama_build_and_test_executable(test-opt.cpp) # SLOW
+llama_build_and_test_executable(test-backend-ops.cpp)
 
 llama_build_and_test_executable(test-rope.cpp)
 

+ 1357 - 0
tests/test-backend-ops.cpp

@@ -0,0 +1,1357 @@
+#include <ggml.h>
+#include <ggml-alloc.h>
+#include <ggml-backend.h>
+#include <ggml-backend-impl.h>
+#include <algorithm>
+#include <array>
+#include <cfloat>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <random>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string>
+#include <thread>
+#include <vector>
+
+
+static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
+    size_t size = ggml_nelements(tensor);
+    std::vector<float> data(size);
+
+    std::random_device rd;
+
+#if 0
+    std::default_random_engine generator(rd());
+    std::uniform_real_distribution<float> distribution(min, max);
+
+    for (size_t i = 0; i < size; i++) {
+        data[i] = distribution(generator);
+    }
+#endif
+    auto init_thread = [&](size_t start, size_t end) {
+        std::default_random_engine generator(rd());
+        std::uniform_real_distribution<float> distribution(min, max);
+
+        for (size_t i = start; i < end; i++) {
+            data[i] = distribution(generator);
+        }
+    };
+
+    size_t n_threads = std::thread::hardware_concurrency();
+    std::vector<std::thread> threads;
+    threads.reserve(n_threads);
+    for (size_t i = 0; i < n_threads; i++) {
+        size_t start =     i*size/n_threads;
+        size_t end   = (i+1)*size/n_threads;
+        threads.emplace_back(init_thread, start, end);
+    }
+    for (auto & t : threads) {
+        t.join();
+    }
+
+    if (tensor->type == GGML_TYPE_F32) {
+        ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
+    } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
+        GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
+        std::vector<uint8_t> dataq(ggml_type_size(tensor->type)*size/ggml_blck_size(tensor->type));
+        int64_t hist[16];
+        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size, hist);
+        ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
+    } else {
+        GGML_ASSERT(false);
+    }
+}
+
+static std::vector<float> tensor_to_float(const ggml_tensor * t) {
+    std::vector<float> tv;
+    tv.reserve(ggml_nelements(t));
+
+    std::vector<uint8_t> buf(ggml_nbytes(t));
+    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
+
+    // access elements by index to avoid gaps in views
+    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < t->ne[0]; i0++) {
+                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0];
+                    float v;
+                    if (t->type == GGML_TYPE_F16) {
+                        v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]);
+                    } else if (t->type == GGML_TYPE_F32) {
+                        v = *(float *) &buf[i];
+                    } else if (t->type == GGML_TYPE_I32) {
+                        v = *(int32_t *) &buf[i];
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+                    tv.push_back(v);
+                }
+            }
+        }
+    }
+
+    return tv;
+}
+
+/*
+static double cosine_similarity(const float * v1, const float * v2, size_t n) {
+    double dot = 0.0;
+    double mag1 = 0.0;
+    double mag2 = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
+            return -1.0f;
+        }
+        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
+            continue;
+        }
+        dot  += v1[i]*v2[i];
+        mag1 += v1[i]*v1[i];
+        mag2 += v2[i]*v2[i];
+    }
+
+    return dot/sqrt(mag1*mag2);
+}
+
+static float distance(const float * v1, const float * v2, size_t n) {
+    double d = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
+            return INFINITY;
+        }
+        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
+            continue;
+        }
+        d += (v1[i] - v2[i])*(v1[i] - v2[i]);
+    }
+
+    return sqrt(d);
+}
+
+static float vec_len(const float * v, size_t n) {
+    double d = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        if (std::isnan(v[i])) {
+            return INFINITY;
+        }
+        if (std::isinf(v[i])) {
+            continue;
+        }
+        d += v[i]*v[i];
+    }
+
+    return sqrt(d);
+}
+*/
+
+// normalized mean squared error = mse(a, b) / mse(a, 0)
+static double nmse(const float * a, const float * b, size_t n) {
+    double mse_a_b = 0.0;
+    double mse_a_0 = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        float a_i = a[i];
+        float b_i = b[i];
+
+        mse_a_b += (a_i - b_i) * (a_i - b_i);
+        mse_a_0 += a_i * a_i;
+    }
+
+    return mse_a_b / mse_a_0;
+}
+
+// utils for printing the variables of the test cases
+#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
+
+template<typename T>
+static std::string var_to_str(const T & x) {
+    return std::to_string(x);
+}
+
+template<typename T, size_t N>
+static std::string var_to_str(const T (&x)[N]) {
+    std::string s = "[";
+    for (size_t i = 0; i < N; i++) {
+        if (i > 0) {
+            s += ",";
+        }
+        s += var_to_str(x[i]);
+    }
+    s += "]";
+    return s;
+}
+
+template<typename T, size_t N>
+static std::string var_to_str(const std::array<T, N> & x) {
+    std::string s = "[";
+    for (size_t i = 0; i < N; i++) {
+        if (i > 0) {
+            s += ",";
+        }
+        s += var_to_str(x[i]);
+    }
+    s += "]";
+    return s;
+}
+
+//static std::string var_to_str(ggml_unary_op unary_op) {
+//    return ggml_unary_op_name(unary_op);
+//}
+
+static std::string var_to_str(ggml_type type) {
+    return ggml_type_name(type);
+}
+
+#define VARS_TO_STR1(a) VAR_TO_STR(a)
+#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
+#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
+#define VARS_TO_STR4(a, b, c, d) VAR_TO_STR(a) + "," + VARS_TO_STR3(b, c, d)
+#define VARS_TO_STR5(a, b, c, d, e) VAR_TO_STR(a) + "," + VARS_TO_STR4(b, c, d, e)
+#define VARS_TO_STR6(a, b, c, d, e, f) VAR_TO_STR(a) + "," + VARS_TO_STR5(b, c, d, e, f)
+#define VARS_TO_STR7(a, b, c, d, e, f, g) VAR_TO_STR(a) + "," + VARS_TO_STR6(b, c, d, e, f, g)
+#define VARS_TO_STR8(a, b, c, d, e, f, g, h) VAR_TO_STR(a) + "," + VARS_TO_STR7(b, c, d, e, f, g, h)
+#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
+#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
+#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
+
+
+// accept FLT_MAX as infinity
+static bool isinf_or_max(float f) {
+    return std::isinf(f) || f == FLT_MAX || f == -FLT_MAX;
+}
+
+static bool ggml_is_view_op(enum ggml_op op) {
+    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+struct test_case {
+    virtual ~test_case() {}
+
+    virtual std::string vars() {
+        return "";
+    }
+
+    virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
+
+    virtual double max_nmse_err() {
+        return 1e-6;
+    }
+
+    virtual void initialize_tensors(ggml_context * ctx) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t);
+        }
+    }
+
+    virtual size_t op_size(ggml_tensor * t) {
+        size_t size = ggml_nbytes(t);
+        // add source tensors
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (t->src[i] != NULL) {
+                size += ggml_nbytes(t->src[i]);
+            }
+        }
+        return size;
+    }
+
+    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+
+        ggml_tensor * out = build_graph(ctx);
+
+        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
+            //printf("  %s: skipping\n", ggml_op_desc(out));
+            ggml_free(ctx);
+            return true;
+        }
+
+        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        fflush(stdout);
+
+        // check if backends support op
+        for (ggml_backend_t backend : {backend1, backend2}) {
+            if (!ggml_backend_supports_op(backend, out)) {
+                printf("not supported\n");
+                ggml_free(ctx);
+                return true;
+            }
+        }
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
+
+        // build graph
+        ggml_cgraph * gf = ggml_new_graph(ctx);
+        ggml_build_forward_expand(gf, out);
+
+        // randomize tensors
+        initialize_tensors(ctx);
+
+        // compare
+        struct callback_userdata {
+            bool   ok;
+            double max_err;
+        };
+
+        callback_userdata ud {
+            true,
+            max_nmse_err(),
+        };
+
+        auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
+            std::vector<float> f1 = tensor_to_float(t1);
+            std::vector<float> f2 = tensor_to_float(t2);
+            callback_userdata * ud = (callback_userdata *) user_data;
+
+            for (size_t i = 0; i < f1.size(); i++) {
+                // check for nans
+                if (std::isnan(f1[i]) || std::isnan(f2[i])) {
+                    printf("NaN at index %zu ", i);
+                    ud->ok = false;
+                    return true;
+                }
+                // check for infs: both must be inf of the same sign, or both must be finite
+                if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
+                    if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
+                        if (std::signbit(f1[i]) != std::signbit(f2[i])) {
+                            printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
+                            ud->ok = false;
+                            return true;
+                        }
+                    } else {
+                        printf("inf mismatch: %f %f ", f1[i], f2[i]);
+                        ud->ok = false;
+                        return true;
+                    }
+                }
+            }
+
+            double err = nmse(f1.data(), f2.data(), f1.size());
+            if (err > ud->max_err) {
+                printf("NMSE = %f ", err);
+                ud->ok = false;
+            }
+            return true;
+        };
+
+        ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
+
+        if (ud.ok) {
+            printf("\033[1;32mOK\033[0m\n");
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        ggml_backend_buffer_free(buf);
+
+        ggml_free(ctx);
+
+        return ud.ok;
+    }
+
+    bool eval_perf(ggml_backend_t backend, const char * op_name) {
+        static const size_t graph_nodes = 8192;
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+
+        ggml_tensor * out = build_graph(ctx);
+
+        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
+            //printf("  %s: skipping\n", ggml_op_desc(out));
+            ggml_free(ctx);
+            return true;
+        }
+
+        int len = printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        fflush(stdout);
+
+        // check if backends support op
+        if (!ggml_backend_supports_op(backend, out)) {
+            printf("not supported\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        // align while also leaving some margin for variations in parameters
+        int align = 20;
+        int last = (len + align - 1) / align * align;
+        if (last - len < 5) {
+            last += align;
+        }
+        last = std::max(last, 60);
+        printf("%*s", last - len, "");
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+        // randomize tensors
+        initialize_tensors(ctx);
+
+        // build graph
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false);
+        ggml_build_forward_expand(gf, out);
+
+        // warmup run
+        ggml_backend_graph_compute(backend, gf);
+
+        // duplicate the op
+        size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
+        int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+        for (int i = 1; i < n_runs; i++) {
+            gf->nodes[gf->n_nodes++] = out;
+        }
+
+        // calculate memory
+        size_t mem = n_runs * op_size(out);
+        auto tensor_op_size = [](ggml_tensor * t) {
+            size_t size = ggml_nbytes(t);
+            // add source tensors
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                if (t->src[i] != NULL) {
+                    size += ggml_nbytes(t->src[i]);
+                }
+            }
+            return size;
+        };
+        for (int i = 0; i < gf->n_nodes; i++) {
+            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out)
+                continue;
+            mem += tensor_op_size(gf->nodes[i]);
+        }
+
+        // run
+        ggml_backend_synchronize(backend);
+
+        int64_t start_time = ggml_time_us();
+        ggml_backend_graph_compute(backend, gf);
+        ggml_backend_synchronize(backend);
+        int64_t end_time = ggml_time_us();
+        double time_us = end_time - start_time;
+
+        printf("    %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n",
+            n_runs,
+            time_us / n_runs,
+            op_size(out) / 1024,
+            mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0);
+
+        ggml_backend_buffer_free(buf);
+
+        ggml_free(ctx);
+
+        return true;
+    }
+};
+
+// GGML_OP_UNARY
+struct test_unary : public test_case {
+    const ggml_unary_op op;
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_unary(ggml_unary_op op,
+            ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {128, 10, 10, 10})
+        : op(op), type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * in = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_unary(ctx, in, op);
+        return out;
+    }
+};
+
+// GGML_OP_GET_ROWS
+struct test_get_rows : public test_case {
+    const ggml_type type;
+    const int n; // cols
+    const int m; // rows
+    const int r; // rows to get
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, n, m, r);
+    }
+
+    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
+        : type(type), n(n), m(m), r(r) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
+        ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
+        ggml_tensor * out = ggml_get_rows(ctx, in, rows);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // rows
+                std::vector<int> data(r);
+                for (int i = 0; i < r; i++) {
+                    data[i] = rand() % m;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_REPEAT
+struct test_repeat : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 4> nr;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, nr);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int, 4> nr = {2, 2, 2, 2})
+        : type(type), ne(ne), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_repeat(ctx, src, target);
+        return out;
+    }
+};
+
+// GGML_OP_DUP
+struct test_dup : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_dup(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_dup(ctx, src);
+        return out;
+    }
+};
+
+// GGML_OP_CPY
+struct test_cpy : public test_case {
+    const ggml_type type_src;
+    const ggml_type type_dst;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type_src, type_dst, ne);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
+    }
+
+    test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1})
+        : type_src(type_src), type_dst(type_dst), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne.data());
+        ggml_tensor * out = ggml_cpy(ctx, src, dst);
+        return out;
+    }
+};
+
+// GGML_OP_CONT
+struct test_cont : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cont(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        src = ggml_transpose(ctx, src);
+        ggml_tensor * out = ggml_cont(ctx, src);
+
+        return out;
+    }
+};
+
+// GGML_OP_ADD
+// GGML_OP_MUL
+// GGML_OP_DIV
+struct test_bin_bcast : public test_case {
+    using op_t = ggml_tensor * (*) (ggml_context *, ggml_tensor *, ggml_tensor *);
+    op_t op;
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 4> nr;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, nr);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 3;
+    }
+
+    test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 1, 1},
+            std::array<int, 4> nr = {1, 2, 1, 1})
+        : op(op), type(type), ne(ne), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = op(ctx, a, b);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (op == ggml_div) {
+                // avoid division by zero
+                init_tensor_uniform(t, 1.0f, 2.0f);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_SCALE
+struct test_scale : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_scale(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * scale = ggml_new_tensor_1d(ctx, type, 1);
+        ggml_tensor * out = ggml_scale(ctx, a, scale);
+        return out;
+    }
+};
+
+// GGML_OP_NORM
+struct test_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_norm(ctx, a, eps);
+        return out;
+    }
+};
+
+// GGML_OP_RMS_NORM
+struct test_rms_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_rms_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        return out;
+    }
+};
+
+// GGML_OP_MUL_MAT
+struct test_mul_mat : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs; // dims 3 and 4
+    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+
+    std::string vars() override {
+        return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1];
+        size_t b = ggml_nbytes(t->src[1]) * m;
+        size_t c  = ggml_nbytes(t);
+        return a + b + c;
+
+        GGML_UNUSED(t);
+    }
+
+    test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array<int64_t, 2> bs = {10, 10},
+            std::array<int64_t, 2> nr = {2, 2})
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]      , bs[1]);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * out = ggml_mul_mat(ctx, a, b);
+        return out;
+    }
+};
+
+// GGML_OP_MUL_MAT_ID
+struct test_mul_mat_id : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int n_mats;
+    const int id;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs; // dims 3 and 4
+    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+
+    std::string vars() override {
+        return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1];
+        size_t b = ggml_nbytes(t->src[1]) * m;
+        size_t c  = ggml_nbytes(t);
+        return a + b + c;
+
+        GGML_UNUSED(t);
+    }
+
+    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int n_mats = 2, int id = 0,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array<int64_t, 2> bs = {10, 10},
+            std::array<int64_t, 2> nr = {2, 2})
+        : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
+            m(m), n(n), k(k), bs(bs), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        std::vector<ggml_tensor *> mats;
+        for (int i = 0; i < n_mats; i++) {
+            ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            mats.push_back(a);
+        }
+        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // ids
+                std::vector<int> data(n_mats);
+                for (int i = 0; i < n_mats; i++) {
+                    data[i] = i;
+                }
+                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
+                ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_SQR
+struct test_sqr : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqr(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sqr(ctx, a);
+        return out;
+    }
+};
+
+// GGML_OP_CLAMP
+struct test_clamp : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float min;
+    float max;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, min, max);
+    }
+
+    test_clamp(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            float min = -0.5f, float max = 0.5f)
+        : type(type), ne(ne), min(min), max(max) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_clamp(ctx, a, min, max);
+        return out;
+    }
+};
+
+// GGML_OP_DIAG_MASK_INF
+struct test_diag_mask_inf : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int n_past;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, n_past);
+    }
+
+    test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            int n_past = 5)
+        : type(type), ne(ne), n_past(n_past) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
+        return out;
+    }
+};
+
+// GGML_OP_SOFT_MAX
+struct test_soft_max : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_soft_max(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_soft_max(ctx, a);
+        return out;
+    }
+};
+
+// GGML_OP_ROPE
+struct test_rope : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    int n_dims;
+    int mode;
+    int n_ctx;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
+    }
+
+    test_rope(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 1},
+            int n_dims = 10, int mode = 0, int n_ctx = 512)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+        ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                std::vector<int> data(ne[2]);
+                for (int i = 0; i < ne[2]; i++) {
+                    data[i] = rand() % n_ctx;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_ALIBI
+struct test_alibi : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    int n_past;
+    int n_head;
+    float bias_max;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne, n_past, n_head, bias_max);
+    }
+
+    test_alibi(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            int n_past = 512, int n_head = 10, float bias_max = 0.5f)
+        : type(type), ne(ne), n_past(n_past), n_head(n_head), bias_max(bias_max) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_alibi(ctx, a, n_past, n_head, bias_max);
+        return out;
+    }
+};
+
+// GGML_OP_IM2COL
+struct test_im2col : public test_case {
+    const ggml_type type_input;
+    const ggml_type type_kernel;
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    // stride
+    const int s0;
+    const int s1;
+    // padding
+    const int p0;
+    const int p1;
+    // dilatation
+    const int d0;
+    const int d1;
+    // mode
+    const bool is_2D;
+
+    std::string vars() override {
+        return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
+    }
+
+    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16,
+            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
+            int s0 = 1, int s1 = 1,
+            int p0 = 1, int p1 = 1,
+            int d0 = 1, int d1 = 1,
+            bool is_2D = true)
+        : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D);
+        return out;
+    }
+};
+
+// GGML_OP_CONCAT
+struct test_concat : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int64_t b_ne2;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, b_ne2);
+    }
+
+    test_concat(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            int64_t b_ne2 = 10)
+        : type(type), ne(ne), b_ne2(b_ne2) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]);
+        ggml_tensor * out = ggml_concat(ctx, a, b);
+        return out;
+    }
+};
+
+// GGML_OP_ARGSORT
+struct test_argsort : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    ggml_sort_order order;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, order);
+    }
+
+    test_argsort(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {16, 10, 10, 10},
+            ggml_sort_order order = GGML_SORT_ASC)
+        : type(type), ne(ne), order(order) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_argsort(ctx, a, order);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // indices
+                std::vector<int> data(ggml_nelements(t));
+                for (int i = 0; i < ggml_nelements(t); i++) {
+                    data[i] = rand();
+                }
+                std::shuffle(data.begin(), data.end(), rng);
+                ggml_backend_tensor_set(t, data.data(), 0, ne[0]*ne[1]*ne[2]*ne[3] * sizeof(int));
+            } else if (t->type == GGML_TYPE_F32) {
+                // initialize with unique values to avoid ties
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<float> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+                }
+            } else {
+                GGML_ASSERT(false);
+            }
+        }
+    }
+};
+
+// GGML_OP_SUM_ROWS
+struct test_sum_rows : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sum_rows(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sum_rows(ctx, a);
+        return out;
+    }
+};
+
+enum test_mode {
+    MODE_TEST,
+    MODE_PERF,
+};
+
+static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
+    std::vector<std::unique_ptr<test_case>> test_cases;
+
+    // unary ops
+    for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+        test_cases.emplace_back(new test_unary((ggml_unary_op) op));
+    }
+
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
+        test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
+    }
+
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
+
+    test_cases.emplace_back(new test_dup());
+    test_cases.emplace_back(new test_cpy());
+    test_cases.emplace_back(new test_cont());
+
+    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
+        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
+        }
+    };
+
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2});
+
+    // stable diffusion
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+
+    test_cases.emplace_back(new test_scale());
+
+    for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
+        test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
+        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
+    }
+
+    const ggml_type all_types[] = {
+        GGML_TYPE_F32, GGML_TYPE_F16,
+        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+        GGML_TYPE_Q8_0,
+        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+        GGML_TYPE_Q6_K
+    };
+
+    for (ggml_type type_a : all_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            // FIXME: CPU crashes on f16xf16
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+        }
+    }
+
+    for (ggml_type type_a : all_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            for (int n_mats : {1, 2, 4}) {
+                for (int id = 0; id < n_mats; id++) {
+                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+                }
+            }
+        }
+    }
+
+    test_cases.emplace_back(new test_sqr());
+    test_cases.emplace_back(new test_clamp());
+
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10,  1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
+
+    test_cases.emplace_back(new test_soft_max());
+
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512)); // llama 7B
+        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512)); // llama 13B
+        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512)); // llama 30B
+        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512)); // llama 65B
+        test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512)); // neox (falcon 7B)
+        test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512)); // neox (falcon 7B)
+        test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
+        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
+        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512)); // neox (stablelm)
+    }
+
+    test_cases.emplace_back(new test_alibi());
+    test_cases.emplace_back(new test_im2col());
+    test_cases.emplace_back(new test_concat());
+
+    for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
+    }
+
+    test_cases.emplace_back(new test_sum_rows());
+
+    // run tests
+    if (mode == MODE_TEST) {
+        ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+
+        size_t n_ok = 0;
+        for (auto & test : test_cases) {
+            if (test->eval(backend, backend_cpu, op_name)) {
+                n_ok++;
+            }
+        }
+        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+
+        ggml_backend_free(backend_cpu);
+
+        return n_ok == test_cases.size();
+    } else if (mode == MODE_PERF) {
+        for (auto & test : test_cases) {
+            test->eval_perf(backend, op_name);
+        }
+        return true;
+    } else {
+        GGML_ASSERT(false);
+    }
+}
+
+static void usage(char ** argv) {
+    printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
+    printf("  valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation)\n");
+    printf("  op names are as given by ggml_op_desc()\n");
+}
+
+int main(int argc, char ** argv) {
+    test_mode mode = MODE_TEST;
+    const char * op_name = NULL;
+    const char * backend = NULL;
+
+    for (int i = 1; i < argc; i++) {
+        if (strcmp(argv[i], "test") == 0) {
+            mode = MODE_TEST;
+        } else if (strcmp(argv[i], "perf") == 0) {
+            mode = MODE_PERF;
+        } else if (strcmp(argv[i], "-o") == 0) {
+            if (i + 1 < argc) {
+                op_name = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-b") == 0) {
+            if (i + 1 < argc) {
+                backend = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
+        } else {
+            usage(argv);
+            return 1;
+        }
+    }
+
+    // enumerate backends
+    printf("Testing %zu backends\n\n", ggml_backend_reg_get_count());
+
+    size_t n_ok = 0;
+
+    for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
+        printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i));
+
+        if (backend != NULL && strcmp(backend, ggml_backend_reg_get_name(i)) != 0) {
+            printf("  Skipping\n");
+            n_ok++;
+            continue;
+        }
+
+        ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
+        GGML_ASSERT(backend != NULL);
+        printf("  Backend name: %s\n", ggml_backend_name(backend));
+
+        bool ok = test_backend(backend, mode, op_name);
+
+        printf("  Backend %s: ", ggml_backend_name(backend));
+        if (ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            n_ok++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        printf("\n");
+
+        ggml_backend_free(backend);
+    }
+
+    printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
+    if (n_ok != ggml_backend_reg_get_count()) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    } else {
+        printf("\033[1;32mOK\033[0m\n");
+        return 0;
+    }
+}

部分文件因文件數量過多而無法顯示