Переглянути джерело

sync : ggml (#5452)

* ggml-alloc : v3 (ggml/727)

* ggml-alloc v3

ggml-ci

* fix ci

ggml-ci

* whisper : check for backend buffer allocation failures

* whisper : avoid leaks when initialization fails

* cleanup

ggml-ci

* style fixes

ggml-ci

* sync : ggml

* update llama.cpp, clip.cpp, export-lora.cpp

* update finetune.cpp, train-text-from-scratch.cpp

ggml-ci

* ggml-backend : reduce alignment to 32 to match gguf and fix mmap

---------

Co-authored-by: slaren <slarengh@gmail.com>
Georgi Gerganov 1 рік тому
батько
коміт
3b169441df

+ 5 - 14
examples/export-lora/export-lora.cpp

@@ -337,24 +337,14 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
     params.mem_buffer = NULL;
     params.no_alloc   = true;
     struct ggml_context * ctx = NULL;
-    struct ggml_allocr * alloc = NULL;
-    struct ggml_cgraph * gf = NULL;
+    struct ggml_gallocr * alloc = NULL;
+    struct ggml_cgraph  * gf = NULL;
 
     ctx   = ggml_init(params);
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
     gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
-    size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
-    ggml_allocr_free(alloc);
-    ggml_free(ctx);
-
-    static std::vector<uint8_t> data_compute;
-    data_compute.resize(alloc_size + tensor_alignment);
 
-    ctx   = ggml_init(params);
-    alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
-    gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
-    ggml_allocr_alloc_graph(alloc, gf);
-    ggml_allocr_free(alloc);
+    ggml_gallocr_alloc_graph(alloc, gf);
 
     struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
     static std::vector<uint8_t> data_work;
@@ -363,6 +353,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
 
     ggml_graph_compute(gf, &cplan);
 
+    ggml_gallocr_free(alloc);
     ggml_free(ctx);
     return true;
 }

+ 37 - 108
examples/finetune/finetune.cpp

@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 #include "llama.h"
 #include "common.h"
 #include "train.h"
@@ -13,8 +14,6 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static const size_t tensor_alignment = 32;
-
 struct my_llama_hparams {
     uint32_t n_vocab    = 32000;
     uint32_t n_ctx      = 512;
@@ -128,7 +127,7 @@ struct my_llama_lora_layer {
 
 struct my_llama_lora {
     struct ggml_context * ctx = NULL;
-    std::vector<uint8_t> data;
+    ggml_backend_buffer_t data;
 
     my_llama_lora_hparams hparams;
 
@@ -372,63 +371,6 @@ static void set_param_lora(struct my_llama_lora * lora) {
     }
 }
 
-static void alloc_lora(struct ggml_allocr * alloc, struct my_llama_lora * lora) {
-    ggml_allocr_alloc(alloc, lora->tok_embeddings_a);
-    ggml_allocr_alloc(alloc, lora->tok_embeddings_b);
-    ggml_allocr_alloc(alloc, lora->norm_a);
-    ggml_allocr_alloc(alloc, lora->norm_b);
-    ggml_allocr_alloc(alloc, lora->output_a);
-    ggml_allocr_alloc(alloc, lora->output_b);
-    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
-        auto & layer = lora->layers[i];
-        ggml_allocr_alloc(alloc, layer.attention_norm_a);
-        ggml_allocr_alloc(alloc, layer.attention_norm_b);
-        ggml_allocr_alloc(alloc, layer.wq_a);
-        ggml_allocr_alloc(alloc, layer.wq_b);
-        ggml_allocr_alloc(alloc, layer.wk_a);
-        ggml_allocr_alloc(alloc, layer.wk_b);
-        ggml_allocr_alloc(alloc, layer.wv_a);
-        ggml_allocr_alloc(alloc, layer.wv_b);
-        ggml_allocr_alloc(alloc, layer.wo_a);
-        ggml_allocr_alloc(alloc, layer.wo_b);
-        ggml_allocr_alloc(alloc, layer.ffn_norm_a);
-        ggml_allocr_alloc(alloc, layer.ffn_norm_b);
-        ggml_allocr_alloc(alloc, layer.w1_a);
-        ggml_allocr_alloc(alloc, layer.w1_b);
-        ggml_allocr_alloc(alloc, layer.w2_a);
-        ggml_allocr_alloc(alloc, layer.w2_b);
-        ggml_allocr_alloc(alloc, layer.w3_a);
-        ggml_allocr_alloc(alloc, layer.w3_b);
-    }
-    ggml_allocr_alloc(alloc, lora->tok_embeddings_a->grad);
-    ggml_allocr_alloc(alloc, lora->tok_embeddings_b->grad);
-    ggml_allocr_alloc(alloc, lora->norm_a->grad);
-    ggml_allocr_alloc(alloc, lora->norm_b->grad);
-    ggml_allocr_alloc(alloc, lora->output_a->grad);
-    ggml_allocr_alloc(alloc, lora->output_b->grad);
-    for (uint32_t i = 0; i < lora->layers.size(); ++i) {
-        auto & layer = lora->layers[i];
-        ggml_allocr_alloc(alloc, layer.attention_norm_a->grad);
-        ggml_allocr_alloc(alloc, layer.attention_norm_b->grad);
-        ggml_allocr_alloc(alloc, layer.wq_a->grad);
-        ggml_allocr_alloc(alloc, layer.wq_b->grad);
-        ggml_allocr_alloc(alloc, layer.wk_a->grad);
-        ggml_allocr_alloc(alloc, layer.wk_b->grad);
-        ggml_allocr_alloc(alloc, layer.wv_a->grad);
-        ggml_allocr_alloc(alloc, layer.wv_b->grad);
-        ggml_allocr_alloc(alloc, layer.wo_a->grad);
-        ggml_allocr_alloc(alloc, layer.wo_b->grad);
-        ggml_allocr_alloc(alloc, layer.ffn_norm_a->grad);
-        ggml_allocr_alloc(alloc, layer.ffn_norm_b->grad);
-        ggml_allocr_alloc(alloc, layer.w1_a->grad);
-        ggml_allocr_alloc(alloc, layer.w1_b->grad);
-        ggml_allocr_alloc(alloc, layer.w2_a->grad);
-        ggml_allocr_alloc(alloc, layer.w2_b->grad);
-        ggml_allocr_alloc(alloc, layer.w3_a->grad);
-        ggml_allocr_alloc(alloc, layer.w3_b->grad);
-    }
-}
-
 static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) {
     const auto & lparams = lora->hparams;
 
@@ -522,18 +464,8 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
 
     set_param_lora(lora);
 
-    // measure data size
-    size_t size = 0;
-    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-        size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
-    }
-
-    // allocate data
-    struct ggml_allocr * alloc = NULL;
-    lora->data.resize(size + tensor_alignment);
-    alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment);
-    alloc_lora(alloc, lora);
-    ggml_allocr_free(alloc);
+    // allocate data for lora tensors
+    lora->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
 }
 
 static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
@@ -579,7 +511,7 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
 static struct ggml_tensor * llama_build_lora_finetune_graphs(
         struct my_llama_model * model,
         struct my_llama_lora  * lora,
-        struct ggml_allocr    * alloc,
+        ggml_gallocr_t          alloc,
         struct ggml_context   * ctx,
         struct ggml_cgraph    * gf,
         struct ggml_cgraph    * gb,
@@ -590,7 +522,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
         const  int              n_tokens,
         const  int              n_batch,
         const  bool             enable_flash_attn,
-        const  bool             enable_checkpointing) {
+        const  bool             enable_checkpointing,
+        const  bool             measure_only) {
 
     ggml_set_scratch(ctx, { 0, 0, nullptr, });
     const int n_past = 0;
@@ -622,13 +555,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
-    ggml_allocr_alloc(alloc, KQ_pos);
-    if (!ggml_allocr_is_measure(alloc)) {
-        int * data = (int *) KQ_pos->data;
-        for (int i = 0; i < N; ++i) {
-            data[i] = n_past + i;
-        }
-    }
+    ggml_set_input(KQ_pos);
 
     // rope has so much parameters that we make a custom function for it
     auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
@@ -780,7 +707,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     // input gradient
     ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
     GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
-    ggml_allocr_alloc(alloc, t36->grad);
+    ggml_set_input(t36->grad);
     // KQ_pos
     ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
 
@@ -805,11 +732,23 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     // note: they will be freed in reverse order
     for (unsigned int i = 0; i < checkpoints.size(); ++i) {
         if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
-            ggml_allocr_alloc(alloc, checkpoints[i]);
+            ggml_set_input(checkpoints[i]);
         }
     }
 
-    ggml_allocr_alloc_graph(alloc, gb);
+    if (measure_only) {
+        ggml_gallocr_reserve(alloc, gb);
+    } else {
+        ggml_gallocr_alloc_graph(alloc, gb);
+
+        // set KQ_pos
+        {
+            int * data = (int *) KQ_pos->data;
+            for (int i = 0; i < N; ++i) {
+                data[i] = n_past + i;
+            }
+        }
+    }
 
     // remove the additional nodes and leafs
     for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
@@ -1663,7 +1602,7 @@ int main(int argc, char ** argv) {
     printf("%s: seen train_samples     %llu\n", __func__, (long long unsigned) train->train_samples);
     printf("%s: seen train_tokens      %llu\n", __func__, (long long unsigned) train->train_tokens);
     printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
-    printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f));
+    printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)), (float) (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)) / (1024.0f*1024.0f));
 
     if (params.only_write_lora) {
         save_train_files_data save_data;
@@ -1690,10 +1629,6 @@ int main(int argc, char ** argv) {
     int n_vocab  = model.hparams.n_vocab;
     int n_batch  = params.common.n_batch;
 
-
-    std::vector<uint8_t> mem_input_data;
-    std::vector<uint8_t> mem_compute_data;
-
     // context for input tensors without their data
     struct ggml_init_params ctx_input_params = {
         ggml_tensor_overhead() * 2, // mem_size
@@ -1706,18 +1641,12 @@ int main(int argc, char ** argv) {
     struct ggml_tensor * tokens_input  = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
     struct ggml_tensor * target_probs  = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab,  n_tokens, n_batch);
 
+    // allocate input tensors
     // measure required memory for input tensors
-    size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
-                            GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
-                            tensor_alignment;
+    ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
+    size_t max_input_size = ggml_backend_buffer_get_size(input_data);
     printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
 
-    // allocate input tensors
-    mem_input_data.resize(max_input_size);
-    ggml_allocr_t alloc_inps = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
-    ggml_allocr_alloc(alloc_inps, tokens_input);
-    ggml_allocr_alloc(alloc_inps, target_probs);
-
     // context for compute tensors without their data
     const size_t estimated_compute_size_wo_data = (
             2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
@@ -1743,7 +1672,7 @@ int main(int argc, char ** argv) {
     // find best evaluation order
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
-        ggml_allocr_t alloc = ggml_allocr_new_measure(tensor_alignment);
+        ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1756,14 +1685,15 @@ int main(int argc, char ** argv) {
             &logits, tokens_input, target_probs,
             n_tokens, n_batch,
             params.common.use_flash,
-            params.common.use_checkpointing
+            params.common.use_checkpointing,
+            true
         );
-        size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+        size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
         if (max_compute_size < best_compute_size) {
             best_compute_size = max_compute_size;
             best_order = gf->order;
         }
-        ggml_allocr_free(alloc);
+        ggml_gallocr_free(alloc);
         ggml_free(ctx_compute);
     }
     size_t max_compute_size = best_compute_size;
@@ -1774,9 +1704,8 @@ int main(int argc, char ** argv) {
         "invalid");
 
     // allocate compute tensors
-    mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
-    ggml_allocr_t alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
     gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
     gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1789,11 +1718,9 @@ int main(int argc, char ** argv) {
         &logits, tokens_input, target_probs,
         n_tokens, n_batch,
         params.common.use_flash,
-        params.common.use_checkpointing
+        params.common.use_checkpointing,
+        false
     );
-    ggml_allocr_free(alloc);
-    ggml_allocr_free(alloc_inps);
-
 
     // tokenize data
     std::vector<llama_token> train_tokens;
@@ -1908,6 +1835,8 @@ int main(int argc, char ** argv) {
     ggml_free(ctx_work);
     ggml_free(ctx_compute);
     ggml_free(ctx_input);
+    ggml_gallocr_free(alloc);
+
 
     int64_t t1 = ggml_time_ms();
     printf("%s: total training time: ", __func__);

+ 81 - 71
examples/llava/clip.cpp

@@ -367,7 +367,7 @@ struct clip_ctx {
     ggml_backend_buffer_t params_buffer = NULL;
     ggml_backend_buffer_t compute_buffer = NULL;
     ggml_backend_t backend = NULL;
-    ggml_allocr * compute_alloc = NULL;
+    ggml_gallocr_t compute_alloc = NULL;
 };
 
 static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
@@ -405,31 +405,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
-    ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
-
-    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-        float * data = (float *)malloc(ggml_nbytes(inp_raw));
-
-        for (size_t i = 0; i < imgs->size; i++) {
-            const int nx = imgs->data[i].nx;
-            const int ny = imgs->data[i].ny;
-            GGML_ASSERT(nx == image_size && ny == image_size);
-
-            const int n = nx * ny;
-
-            for (int b = 0; b < batch_size; b++) {
-                for (int k = 0; k < 3; k++) {
-                    for (int y = 0; y < ny; y++) {
-                        for (int x = 0; x < nx; x++) {
-                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
-                        }
-                    }
-                }
-            }
-        }
-        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
-        free(data);
-    }
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
 
@@ -438,13 +415,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
 
     // concat class_embeddings and patch_embeddings
     struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-    ggml_allocr_alloc(ctx->compute_alloc, embeddings);
-    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-        void* zero_mem = malloc(ggml_nbytes(embeddings));
-        memset(zero_mem, 0, ggml_nbytes(embeddings));
-        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
-        free(zero_mem);
-    }
+    ggml_set_name(embeddings, "embeddings");
+    ggml_set_input(embeddings);
 
     embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
             embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
@@ -453,15 +425,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
 
     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
-    ggml_allocr_alloc(ctx->compute_alloc, positions);
-    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-        int* positions_data = (int*)malloc(ggml_nbytes(positions));
-        for (int i = 0; i < num_positions; i++) {
-            positions_data[i] = i;
-        }
-        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
-        free(positions_data);
-    }
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
 
     embeddings =
         ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
@@ -560,15 +525,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
         struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
-        ggml_allocr_alloc(ctx->compute_alloc, patches);
-        if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-            int* patches_data = (int*)malloc(ggml_nbytes(patches));
-            for (int i = 0; i < num_patches; i++) {
-                patches_data[i] = i + 1;
-            }
-            ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
-            free(patches_data);
-        }
+        ggml_set_name(patches, "patches");
+        ggml_set_input(patches);
 
         // shape [1, 576, 1024]
         // ne is whcn, ne = [1024, 576, 1, 1]
@@ -809,7 +767,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     }
 
     // data
-    size_t buffer_size = 0;
+    size_t model_size = 0;
     {
         for (int i = 0; i < n_tensors; ++i) {
             const char * name = gguf_get_tensor_name(ctx, i);
@@ -817,7 +775,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             enum ggml_type type = gguf_get_tensor_type(ctx, i);
             struct ggml_tensor * cur = ggml_get_tensor(meta, name);
             size_t tensor_size = ggml_nbytes(cur);
-            buffer_size += tensor_size;
+            model_size += tensor_size;
             if (verbosity >= 3) {
                 printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
                        __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
@@ -825,8 +783,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         }
     }
 
-    buffer_size += n_tensors * 128 /* CLIP PADDING */;
-
     clip_ctx * new_clip = new clip_ctx;
 
     // update projector type
@@ -886,12 +842,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             printf("%s: text_encoder:   %d\n", __func__, new_clip->has_text_encoder);
             printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
             printf("%s: llava_projector:  %d\n", __func__, new_clip->has_llava_projector);
-            printf("%s: model size:     %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0);
+            printf("%s: model size:     %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
             printf("%s: metadata size:  %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
         }
     }
 
-    printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors);
+    printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors);
 
     // load tensors
     {
@@ -925,12 +881,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         }
 
         // alloc memory and offload data
-        new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
-        ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
+        new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
         for (int i = 0; i < n_tensors; ++i) {
             const char * name = gguf_get_tensor_name(ctx, i);
             struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
-            ggml_allocr_alloc(alloc, cur);
             const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
             fin.seekg(offset, std::ios::beg);
             if (!fin) {
@@ -949,7 +903,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
                 ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
             }
         }
-        ggml_allocr_free(alloc);
         fin.close();
     }
 
@@ -1077,15 +1030,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     // measure mem requirement and allocate
     {
         new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
-        new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
+        new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
         clip_image_f32_batch batch;
         batch.size = 1;
         ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
-        size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
-        ggml_allocr_free(new_clip->compute_alloc);
-        new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
-        new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
-
+        ggml_gallocr_reserve(new_clip->compute_alloc, gf);
+        size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
         printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
     }
 
@@ -1267,12 +1217,72 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         GGML_ASSERT(batch_size == 1); // TODO: support multiple images
     }
 
-    // reset alloc buffer to clean the memory from previous invocations
-    ggml_allocr_reset(ctx->compute_alloc);
-
     // build the inference graph
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
-    ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
+    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+    // set inputs
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+    const int image_size = hparams.image_size;
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
+    const int num_positions = num_patches + 1;
+
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+
+        for (size_t i = 0; i < imgs->size; i++) {
+            const int nx = imgs->data[i].nx;
+            const int ny = imgs->data[i].ny;
+            GGML_ASSERT(nx == image_size && ny == image_size);
+
+            const int n = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
+    }
+
+    {
+        struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+        void* zero_mem = malloc(ggml_nbytes(embeddings));
+        memset(zero_mem, 0, ggml_nbytes(embeddings));
+        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+        free(zero_mem);
+    }
+
+    {
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+        int* positions_data = (int*)malloc(ggml_nbytes(positions));
+        for (int i = 0; i < num_positions; i++) {
+            positions_data[i] = i;
+        }
+        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+        free(positions_data);
+    }
+
+    {
+        struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+        int* patches_data = (int*)malloc(ggml_nbytes(patches));
+        for (int i = 0; i < num_patches; i++) {
+            patches_data[i] = i + 1;
+        }
+        ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+        free(patches_data);
+    }
 
     if (ggml_backend_is_cpu(ctx->backend)) {
         ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);

+ 32 - 80
examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 #include "common.h"
 #include "train.h"
 #include "llama.h"
@@ -19,8 +20,6 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static const size_t tensor_alignment = 32;
-
 struct my_llama_hparams {
     uint32_t n_vocab = 32000;
     uint32_t n_ctx   = 512;
@@ -58,7 +57,7 @@ struct my_llama_layer {
 
 struct my_llama_model {
     struct ggml_context * ctx = NULL;
-    std::vector<uint8_t> data;
+    ggml_backend_buffer_t data = NULL;
 
     my_llama_hparams hparams;
 
@@ -147,39 +146,6 @@ static void set_param_model(struct my_llama_model * model) {
     }
 }
 
-static void alloc_model(struct ggml_allocr * alloc, struct my_llama_model * model) {
-    ggml_allocr_alloc(alloc, model->tok_embeddings);
-    ggml_allocr_alloc(alloc, model->norm);
-    ggml_allocr_alloc(alloc, model->output);
-    for (uint32_t i = 0; i < model->layers.size(); ++i) {
-        auto & layer = model->layers[i];
-        ggml_allocr_alloc(alloc, layer.attention_norm);
-        ggml_allocr_alloc(alloc, layer.wq);
-        ggml_allocr_alloc(alloc, layer.wk);
-        ggml_allocr_alloc(alloc, layer.wv);
-        ggml_allocr_alloc(alloc, layer.wo);
-        ggml_allocr_alloc(alloc, layer.ffn_norm);
-        ggml_allocr_alloc(alloc, layer.w1);
-        ggml_allocr_alloc(alloc, layer.w2);
-        ggml_allocr_alloc(alloc, layer.w3);
-    }
-    ggml_allocr_alloc(alloc, model->tok_embeddings->grad);
-    ggml_allocr_alloc(alloc, model->norm->grad);
-    ggml_allocr_alloc(alloc, model->output->grad);
-    for (uint32_t i = 0; i < model->layers.size(); ++i) {
-        auto & layer = model->layers[i];
-        ggml_allocr_alloc(alloc, layer.attention_norm->grad);
-        ggml_allocr_alloc(alloc, layer.wq->grad);
-        ggml_allocr_alloc(alloc, layer.wk->grad);
-        ggml_allocr_alloc(alloc, layer.wv->grad);
-        ggml_allocr_alloc(alloc, layer.wo->grad);
-        ggml_allocr_alloc(alloc, layer.ffn_norm->grad);
-        ggml_allocr_alloc(alloc, layer.w1->grad);
-        ggml_allocr_alloc(alloc, layer.w2->grad);
-        ggml_allocr_alloc(alloc, layer.w3->grad);
-    }
-}
-
 static void init_model(struct my_llama_model * model) {
     const auto & hparams = model->hparams;
 
@@ -252,17 +218,8 @@ static void init_model(struct my_llama_model * model) {
 
     set_param_model(model);
 
-    // measure data size
-    size_t size = 0;
-    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-        size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
-    }
-
     // allocate data
-    struct ggml_allocr * alloc = NULL;
-    model->data.resize(size + tensor_alignment);
-    alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
-    alloc_model(alloc, model);
+    model->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
 }
 
 static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
@@ -297,7 +254,7 @@ static void randomize_model(struct my_llama_model * model, int seed, float mean,
 
 static struct ggml_tensor * llama_build_train_graphs(
         struct my_llama_model * model,
-        struct ggml_allocr    * alloc,
+        ggml_gallocr_t          alloc,
         struct ggml_context   * ctx,
         struct ggml_cgraph    * gf,
         struct ggml_cgraph    * gb,
@@ -308,7 +265,8 @@ static struct ggml_tensor * llama_build_train_graphs(
         const  int              n_tokens,
         const  int              n_batch,
         const  bool             enable_flash_attn,
-        const  bool             enable_checkpointing) {
+        const  bool             enable_checkpointing,
+        const  bool             measure_only) {
 
     ggml_set_scratch(ctx, { 0, 0, nullptr, });
     const int n_past = 0;
@@ -334,13 +292,7 @@ static struct ggml_tensor * llama_build_train_graphs(
 
     // KQ_pos - contains the positions
     struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
-    ggml_allocr_alloc(alloc, KQ_pos);
-    if (!ggml_allocr_is_measure(alloc)) {
-        int * data = (int *) KQ_pos->data;
-        for (int i = 0; i < N; ++i) {
-            data[i] = n_past + i;
-        }
-    }
+    ggml_set_input(KQ_pos);
 
     // rope has so much parameters that we make a custom function for it
     auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
@@ -448,21 +400,31 @@ static struct ggml_tensor * llama_build_train_graphs(
         // KQ_pos
         ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
         GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
-
-        ggml_allocr_alloc(alloc, t36->grad);
+        ggml_set_input(t36->grad);
 
         // allocating checkpoints in one block to reduce memory fragmentation
         // note: they will be freed in reverse order
         for (int i = 0; i < (int) checkpoints.size(); ++i) {
             if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
-                ggml_allocr_alloc(alloc, checkpoints[i]);
+                ggml_set_input(checkpoints[i]);
             }
         }
 
         //int n_leafs_after = gb->n_leafs;
         //int n_nodes_after = gb->n_nodes;
+        if (measure_only) {
+            // FIXME: will still allocate
+            ggml_gallocr_reserve(alloc, gb);
+        } else {
+            ggml_gallocr_alloc_graph(alloc, gb);
 
-        ggml_allocr_alloc_graph(alloc, gb);
+            if (!measure_only) {
+                int * data = (int *) KQ_pos->data;
+                for (int i = 0; i < N; ++i) {
+                    data[i] = n_past + i;
+                }
+            }
+        }
 
         // remove the additional nodes and leafs
         for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
@@ -1046,7 +1008,7 @@ int main(int argc, char ** argv) {
     printf("%s: seen train_samples     %llu\n", __func__, (long long unsigned) train->train_samples);
     printf("%s: seen train_tokens      %llu\n", __func__, (long long unsigned) train->train_tokens);
     printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
-    printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + model.data.size()), (float) (ggml_used_mem(model.ctx) + model.data.size()) / (1024.0f*1024.0f));
+    printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)), (float) (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)) / (1024.0f*1024.0f));
 
     if (params.only_write_model) {
         save_train_files_data save_data;
@@ -1073,11 +1035,6 @@ int main(int argc, char ** argv) {
     int n_vocab  = model.hparams.n_vocab;
     int n_batch  = params.common.n_batch;
 
-    std::vector<uint8_t> mem_input_data;
-    std::vector<uint8_t> mem_compute_data;
-
-    ggml_allocr * alloc = NULL;
-
     // context for input tensors without their data
     struct ggml_init_params ctx_input_params = {
         ggml_tensor_overhead() * 2, // mem_size
@@ -1091,16 +1048,10 @@ int main(int argc, char ** argv) {
     struct ggml_tensor * target_probs  = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab,  n_tokens, n_batch);
 
     // measure required memory for input tensors
-    size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
-                            GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
-                            tensor_alignment;
-    printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
-
     // allocate input tensors
-    mem_input_data.resize(max_input_size);
-    alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
-    ggml_allocr_alloc(alloc, tokens_input);
-    ggml_allocr_alloc(alloc, target_probs);
+    ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
+    size_t max_input_size = ggml_backend_buffer_get_size(input_data);
+    printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
 
     // context for compute tensors without their data
     const size_t estimated_compute_size_wo_data = (
@@ -1127,7 +1078,7 @@ int main(int argc, char ** argv) {
     // find best evaluation order
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
-        alloc = ggml_allocr_new_measure(tensor_alignment);
+        ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1140,9 +1091,10 @@ int main(int argc, char ** argv) {
             &logits, tokens_input, target_probs,
             n_tokens, n_batch,
             params.common.use_flash,
-            params.common.use_checkpointing
+            params.common.use_checkpointing,
+            true
         );
-        size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+        size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
         if (max_compute_size < best_compute_size) {
             best_compute_size = max_compute_size;
             best_order = gf->order;
@@ -1157,9 +1109,8 @@ int main(int argc, char ** argv) {
         "invalid");
 
     // allocate compute tensors
-    mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
-    alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
     gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
     gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@@ -1172,7 +1123,8 @@ int main(int argc, char ** argv) {
         &logits, tokens_input, target_probs,
         n_tokens, n_batch,
         params.common.use_flash,
-        params.common.use_checkpointing
+        params.common.use_checkpointing,
+        false
     );
 
     std::vector<llama_token> train_tokens;

Різницю між файлами не показано, бо вона завелика
+ 560 - 485
ggml-alloc.c


+ 39 - 65
ggml-alloc.h

@@ -6,88 +6,62 @@
 extern "C" {
 #endif
 
-struct ggml_backend;
-struct ggml_backend_buffer;
-struct ggml_backend_buffer_type;
+typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+typedef struct ggml_backend * ggml_backend_t;
 
-//
-// Legacy API
-//
-
-typedef struct ggml_allocr * ggml_allocr_t;
-
-// initialize allocator for use with CPU backend only
-GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
-GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
-
-// initialize allocator for use with ggml-backend
-GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
-GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
-GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
-
-GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
-
-// tell the allocator to parse nodes following the order described in the list
-// you should call this if your graph are optimized to execute out-of-order
-GGML_API void   ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
-
-GGML_API void   ggml_allocr_free       (ggml_allocr_t alloc);
-GGML_API bool   ggml_allocr_is_measure (ggml_allocr_t alloc);
-GGML_API void   ggml_allocr_reset      (ggml_allocr_t alloc);
-GGML_API void   ggml_allocr_alloc      (ggml_allocr_t alloc, struct ggml_tensor * tensor);
-GGML_API size_t ggml_allocr_max_size   (ggml_allocr_t alloc);
-
-GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
+// Tensor allocator
+typedef struct ggml_tallocr * ggml_tallocr_t;
 
-//
-// ggml-backend v2 API
-//
+GGML_API ggml_tallocr_t ggml_tallocr_new(ggml_backend_buffer_t buffer);
+GGML_API void           ggml_tallocr_free(ggml_tallocr_t talloc);
+GGML_API void           ggml_tallocr_alloc(ggml_tallocr_t talloc, struct ggml_tensor * tensor);
 
-// Separate tensor and graph allocator objects
-// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
-// The original API is kept as a wrapper around the new API
+// Graph allocator
+/*
+  Example usage:
+    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_bacckend_cpu_buffer_type());
 
-// Tensor allocator
-typedef struct ggml_tallocr * ggml_tallocr_t;
+    // optional: create a worst-case graph and reserve the buffers to avoid reallocations
+    ggml_gallocr_reserve(galloc, build_graph(max_batch));
 
-GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
-GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
-GGML_API ggml_tallocr_t ggml_tallocr_new_from_buft(struct ggml_backend_buffer_type * buft, size_t size);
-GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
-GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
-GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_buffer_type * buft);
-GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
+    // allocate the graph
+    struct ggml_cgraph * graph = build_graph(batch);
+    ggml_gallocr_alloc_graph(galloc, graph);
 
-GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
+    printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
 
-GGML_API void   ggml_tallocr_free       (ggml_tallocr_t talloc);
-GGML_API bool   ggml_tallocr_is_measure (ggml_tallocr_t talloc);
-GGML_API void   ggml_tallocr_reset      (ggml_tallocr_t talloc);
-GGML_API void   ggml_tallocr_alloc      (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
-GGML_API size_t ggml_tallocr_max_size   (ggml_tallocr_t talloc);
+    // evaluate the graph
+    ggml_backend_graph_compute(backend, graph);
+*/
 
+// special tensor flags for use with the graph allocator:
+//   ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
+//   ggml_set_output(): output tensors are never freed and never overwritten
 
-// Graph allocator
 typedef struct ggml_gallocr * ggml_gallocr_t;
 
-GGML_API ggml_gallocr_t ggml_gallocr_new(void);
-GGML_API void   ggml_gallocr_free(ggml_gallocr_t galloc);
+GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
+GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
+GGML_API void           ggml_gallocr_free(ggml_gallocr_t galloc);
 
-GGML_API void   ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n);
-GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph);
+// pre-allocate buffers from a measure graph - does not allocate or modify the graph
+// call with a worst-case graph to avoid buffer reallocations
+// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
+// returns false if the buffer allocation failed
+GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+GGML_API bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids);
 
-// Allocate tensors from the allocators given by the hash table
-GGML_API void   ggml_gallocr_alloc_graph_n(
-                    ggml_gallocr_t galloc,
-                    struct ggml_cgraph * graph,
-                    struct ggml_hash_set hash_set,
-                    ggml_tallocr_t * hash_node_talloc);
+// automatic reallocation if the topology changes when using a single buffer
+// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
+GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
 
+GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
 
 // Utils
 // Create a buffer and allocate all the tensors in a ggml_context
-GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
-GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
 
 #ifdef  __cplusplus
 }

Різницю між файлами не показано, бо вона завелика
+ 227 - 257
ggml-backend.c


+ 5 - 10
ggml-backend.h

@@ -130,11 +130,7 @@ extern "C" {
 
         // in build_graph:
         build_graph(...) {
-            // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
-            alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
-            ggml_allocr_alloc(alloc_cpu, tensor);
-
-            // manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
+            // manually assign nodes to a backend (optional, should not be needed in most cases)
             struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
             ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
         }
@@ -164,20 +160,19 @@ extern "C" {
     GGML_API ggml_backend_sched_t  ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
     GGML_API void                  ggml_backend_sched_free(ggml_backend_sched_t sched);
     // Initialize backend buffers from a measure graph
-    GGML_API void                  ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+    GGML_API bool                  ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
     // Get the number of splits of the last graph
     GGML_API int                   ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
 
-    GGML_API ggml_tallocr_t        ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
-    GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
+    GGML_API size_t                ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
 
     GGML_API void                  ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t        ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
     // Allocate and compute graph on the backend scheduler
-    GGML_API void                  ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API bool                  ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
 
-    // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs
+    // Reset all assignments and allocators - must be called before changing the node backends
     GGML_API void                  ggml_backend_sched_reset(ggml_backend_sched_t sched);
 
     // Set a callback to be called for each resulting node during graph compute

+ 19 - 9
ggml.c

@@ -2649,7 +2649,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.nb           =*/ { 0, 0, 0, 0 },
         /*.op           =*/ GGML_OP_NONE,
         /*.op_params    =*/ { 0 },
-        /*.is_param     =*/ false,
+        /*.flags        =*/ 0,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
         /*.perf_runs    =*/ 0,
@@ -6551,7 +6551,7 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
 void ggml_set_param(
         struct ggml_context * ctx,
         struct ggml_tensor * tensor) {
-    tensor->is_param = true;
+    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 
     GGML_ASSERT(tensor->grad == NULL);
     tensor->grad = ggml_dup_tensor(ctx, tensor);
@@ -15367,7 +15367,7 @@ static struct ggml_tensor * ggml_recompute_graph_node(
         return NULL;
     }
 
-    if (node->is_param) {
+    if (node->flags & GGML_TENSOR_FLAG_PARAM) {
         return node;
     }
 
@@ -15401,7 +15401,7 @@ static struct ggml_tensor * ggml_recompute_graph_node(
 
     clone->op       = node->op;
     clone->grad     = node->grad;
-    clone->is_param = node->is_param;
+    clone->flags    = node->flags;
     clone->extra    = node->extra;
     for (int k = 0; k < GGML_MAX_DIMS; ++k) {
         clone->nb[k] = node->nb[k];
@@ -16433,7 +16433,7 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     for (int i = 0; i < gf->n_nodes; i++) {
         struct ggml_tensor * node = gf->nodes[i];
 
-        if (node->is_param) {
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
             ggml_build_forward_expand(gb, node->grad);
         }
@@ -17918,7 +17918,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ", node->perf_runs,
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
                 (double) node->perf_time_us / 1000.0,
@@ -18011,7 +18011,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
             continue;
         }
 
-        if (node->is_param) {
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             snprintf(color, sizeof(color), "yellow");
         } else if (node->grad) {
             if (ggml_graph_find(gf, node)) {
@@ -18185,7 +18185,7 @@ static enum ggml_opt_result ggml_opt_adam(
     int np = 0;
     int64_t nx = 0;
     for (int i = 0; i < gf->n_nodes; ++i) {
-        if (gf->nodes[i]->is_param) {
+        if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
 
             GGML_ASSERT(np < GGML_MAX_PARAMS);
@@ -18548,7 +18548,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     int np = 0;
     int nx = 0;
     for (int i = 0; i < gf->n_nodes; ++i) {
-        if (gf->nodes[i]->is_param) {
+        if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
 
             GGML_ASSERT(np < GGML_MAX_PARAMS);
@@ -19023,6 +19023,16 @@ enum ggml_opt_result ggml_opt_resume_g(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+void ggml_set_input(struct ggml_tensor * tensor) {
+    tensor->flags |= GGML_TENSOR_FLAG_INPUT;
+}
+
+void ggml_set_output(struct ggml_tensor * tensor) {
+    tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
 void ggml_quantize_init(enum ggml_type type) {
     ggml_critical_section_start();
 

+ 15 - 3
ggml.h

@@ -505,11 +505,17 @@ extern "C" {
 
     enum ggml_log_level {
         GGML_LOG_LEVEL_ERROR = 2,
-        GGML_LOG_LEVEL_WARN = 3,
-        GGML_LOG_LEVEL_INFO = 4,
+        GGML_LOG_LEVEL_WARN  = 3,
+        GGML_LOG_LEVEL_INFO  = 4,
         GGML_LOG_LEVEL_DEBUG = 5
     };
 
+    enum ggml_tensor_flag {
+        GGML_TENSOR_FLAG_INPUT  = 1,
+        GGML_TENSOR_FLAG_OUTPUT = 2,
+        GGML_TENSOR_FLAG_PARAM  = 4,
+    };
+
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -543,7 +549,7 @@ extern "C" {
         // op params - allocated as int32_t for alignment
         int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
-        bool is_param;
+        int32_t flags;
 
         struct ggml_tensor * grad;
         struct ggml_tensor * src[GGML_MAX_SRC];
@@ -2092,6 +2098,12 @@ extern "C" {
             ggml_opt_callback callback,
             void * callback_data);
 
+    //
+    // tensor flags
+    //
+    GGML_API void ggml_set_input(struct ggml_tensor * tensor);
+    GGML_API void ggml_set_output(struct ggml_tensor * tensor);
+
     //
     // quantization
     //

+ 95 - 86
llama.cpp

@@ -1872,8 +1872,6 @@ struct llama_context {
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_t sched = nullptr;
-    // allocator for the input tensors
-    ggml_tallocr * alloc = nullptr;
 
     // input tensors
     ggml_backend_buffer_t buf_input = nullptr;
@@ -7199,12 +7197,10 @@ struct llm_build_context {
 
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-     const llama_batch & batch) {
+     const llama_batch & batch,
+                  bool   worst_case) {
     const auto & model = lctx.model;
 
-    // check if we should build the worst-case graph (for memory measurement)
-    const bool worst_case = ggml_tallocr_is_measure(lctx.alloc);
-
     // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
     llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
         if (il >= 0) {
@@ -7225,77 +7221,6 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct llm_build_context llm(lctx, batch, cb, worst_case);
 
-    //
-    // set input data
-    //
-
-    if (!ggml_tallocr_is_measure(lctx.alloc)) {
-        if (batch.token) {
-            const int64_t n_tokens = batch.n_tokens;
-
-            ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
-        }
-
-        if (batch.embd) {
-            const int64_t n_embd   = llm.n_embd;
-            const int64_t n_tokens = batch.n_tokens;
-
-            ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-        }
-
-        if (batch.pos) {
-            const int64_t n_tokens = batch.n_tokens;
-
-            ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
-        }
-
-        {
-            const int64_t n_kv     = llm.n_kv;
-            const int64_t n_tokens = batch.n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-            float * data = (float *) lctx.inp_KQ_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_pos    pos    = batch.pos[j];
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
-                            (llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
-                            f = -INFINITY;
-                        } else {
-                            f = 0;
-                        }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                    }
-                }
-            }
-        }
-
-        if (llm.do_rope_shift) {
-            const int64_t n_ctx = llm.n_ctx;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-            int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-            for (int i = 0; i < n_ctx; ++i) {
-                data[i] = lctx.kv_self.cells[i].delta;
-            }
-        }
-
-        {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
-            float * data = (float *) lctx.inp_sum->data;
-
-            for (int i = 0; i < batch.n_tokens; ++i) {
-                data[i] = 1.0f/float(batch.n_tokens);
-            }
-        }
-    }
-
     llm.init();
 
     switch (model.arch) {
@@ -7384,6 +7309,83 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }
 
+static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+    //
+    // set input data
+    //
+
+    const auto & hparams = lctx.model.hparams;
+    const auto & cparams = lctx.cparams;
+    const auto & kv_self = lctx.kv_self;
+
+    if (batch.token) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
+    }
+
+    if (batch.embd) {
+        const int64_t n_embd   = hparams.n_embd;
+        const int64_t n_tokens = batch.n_tokens;
+
+        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+    }
+
+    if (batch.pos) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+    }
+
+    {
+        const int64_t n_kv     = kv_self.n;
+        const int64_t n_tokens = batch.n_tokens;
+
+        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+
+        float * data = (float *) lctx.inp_KQ_mask->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    float f;
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                        f = -INFINITY;
+                    } else {
+                        f = 0;
+                    }
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                }
+            }
+        }
+    }
+
+
+    {
+        assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
+        float * data = (float *) lctx.inp_sum->data;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            data[i] = 1.0f/float(batch.n_tokens);
+        }
+    }
+
+    if (kv_self.has_shift) {
+        const int64_t n_ctx = cparams.n_ctx;
+
+        assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+
+        int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+
+        for (int i = 0; i < n_ctx; ++i) {
+            data[i] = lctx.kv_self.cells[i].delta;
+        }
+    }
+}
+
 // decode a batch of tokens by evaluating the transformer
 //
 //   - lctx:      llama context
@@ -7482,7 +7484,7 @@ static int llama_decode_internal(
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, batch);
+    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
 
     // the output is always the last tensor in the graph
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@@ -7527,6 +7529,9 @@ static int llama_decode_internal(
     if (lctx.backend_cpu != nullptr) {
         ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
     }
+
+    llama_set_inputs(lctx, batch);
+
     ggml_backend_sched_graph_compute(lctx.sched, gf);
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
@@ -11278,23 +11283,27 @@ struct llama_context * llama_new_context_with_model(
             ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
 
             ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
-            ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
 
             // build worst-case graph
             int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
             int n_past = cparams.n_ctx - n_tokens;
             llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
+            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
 
             // initialize scheduler with the worst-case graph
-            ggml_backend_sched_init_measure(ctx->sched, gf);
-            ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
+            if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
 
-            for (ggml_backend_t backend : ctx->backends) {
-                ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(ctx->sched, backend);
+            for (size_t i = 0; i < ctx->backends.size(); i++) {
+                ggml_backend_t backend = ctx->backends[i];
+                ggml_backend_buffer_type_t buft = backend_buft[i];
+                size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
                 LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                        ggml_backend_buffer_name(buf),
-                        ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
+                        ggml_backend_buft_name(buft),
+                        size / 1024.0 / 1024.0);
             }
 
             // note: the number of splits during measure is higher than during inference due to the kv shift

+ 1 - 1
scripts/sync-ggml.last

@@ -1 +1 @@
-2c7cf49810d523b9632da393a9e8270b60bf3b24
+5070f078a67c18c11736e78316ab715ca9afde16

Деякі файли не було показано, через те що забагато файлів було змінено