Răsfoiți Sursa

ggml : implement backward pass for llama + small training-llama-from-scratch example (#1360)

* implement 8 of 14 missing backward pass operations used by llama

- GGML_OP_ADD_AT
- GGML_OP_CPY
- GGML_OP_MUL_MAT (src0.grad)
- GGML_OP_PERMUTE
- GGML_OP_RESHAPE
- GGML_OP_SCALE
- GGML_OP_TRANSPOSE
- GGML_OP_VIEW

implement additional ggml operation GGML_OP_ADD_AT, which is necessary for backward pass of GGML_OP_VIEW.

this operation adds src1 to src0 with data offset, i.e. to view(src0, ..., offset).
the values are return in a tensor size of src0. values outside of [data+offset:data+offset+nbytes(src1)] are just the original values from src0.

still missing backward passes for llama:

- GGML_OP_DIAG_MASK_INF
- GGML_OP_GET_ROWS
- GGML_OP_RMS_NORM
- GGML_OP_ROPE
- GGML_OP_SILU
- GGML_OP_SOFT_MAX

* implement 5 of 6 missing backward pass operations used by llama

- GGML_OP_DIAG_MASK_INF
- GGML_OP_GET_ROWS
- GGML_OP_RMS_NORM
- GGML_OP_SILU
- GGML_OP_SOFT_MAX

add necessary ggml operations GGML_OP_ADD1, GGML_OP_SILU_BACK, GGML_OP_RMS_NORM_BACK, GGML_OP_DIAG_MASK_ZERO, and GGML_OP_ROPE_BACK

GGML_OP_ADD1 is necessary to add a scalar value in the backward pass of GGML_OP_SOFT_MAX
GGML_OP_ADD1 could also be replaced by using GGML_OP_ADD and GGML_OP_REPEAT, but the performance would be worse. additionally GGML_OP_REPEAT will return unexpected value when the the input to GGML_OP_SOFT_MAX contains only a single scalar. in this case GGML_OP_REPEAT will not return the value that should be repeated (src1) but the value which shape the result should take (src0). So in this case it can not replace GGML_OP_ADD1.

GGML_OP_SILU_BACK, GGML_OP_RMS_NORM_BACK and GGML_OP_ROPE_BACK are necessary for backward pass of GGML_OP_SILU, GGML_OP_RMS_NORM and GGML_OP_ROPE. The backward pass for these functions cannot be easily composed of existing operations. Since the backward pass builds a computation graph we need operations forward pass implementations of the the required backward passes. Sounds a bit confusing at first, I know...

GGML_OP_DIAG_MASK_ZERO is necessary for backward pass of GGML_OP_DIAG_MASK_INF.

Some operations where previously inplace-only. for backward pass there needs to be non-inplace variants.
staying consistent with other operations that have non-inplace and inplace variants, the operations are changed to non-inplace and
functions with "_inplace" are added which are inplace.
in llama we need to call the inplace variants so that it is implemented as before.
for llama backward pass we need to use the non-inplace variants.

still not completely implemented backward passes for llama:

- GGML_OP_ROPE: needs forward pass for GGML_OP_ROPE_BACK
- GGML_OP_GET_ROWS: only necessary for tokenizer

* norm & rms_norm can not be threaded:

after investigation rms norm for quite some time I come to the conclusion that neither norm, nor rms_norm can be threaded, because we need mean over all items, not just of the slices each thread sees.

* remove already resolved TODO

* implement backward pass of ggml_rope and ggml_rope_back

* implement backward pass for ggml_get_rows and for new operation ggml_get_rows_back

* add test-grad0.c

* use GGML_PRINT_DEBUG for debug messages which will otherwise flood the console

* test both gradients of mul_mat

* disable graph dot export as it floods console

* bug fixes for silu_back

* successfully test silu backward

* bug fix for scale backward pass

use sum instead of mean for gradient of scalar scale parameter

* successfully test scale backward

* improve performance of sum backward pass

use add1(x,y) instead of add(x,repeat(y,x))

* improve performance of sqr backward pass

use scale(x,y) instead of mul(x,repeat(y,x))

* successfully test rope backward

* bug fix for cpy backward pass

* successfully test cpy backward

* bug fix for reshape backward pass

* successfully test reshape backward

* add test-opt.c

this uses ggml_opt to train a,b for minimal e=sum(sqr(c - a*b)) for random initial a,b,c

* correctly implement softmax backward pass using new operation ggml_diag

ggml_diag constructs diagonal matrices with entries.
ggml_diag(shape[a,1,c,d]) -> shape[a,a,c,d]

* successfully test soft_max backward

* align shape annotations

* add shape annotations for llama

* de-duplicate ggml_forward_dup code taking care of contiguous tensors of same type.

with this we can duplicate tensor of any typ as long as they are contiguous.

* fix ggml_compute_forward_dup_same_cont for when nelements < nthreads

when more threads are used than elements exist ie1 was less than ie0, resulting in invalid negative byte count argument in memcpy

* bug fix for add_at forward

required for view backward pass

src0 values must be copied to dst, because during addition we don't touch all dst elements in contrast to the normal add function.

* successfully test view backward

* minor code format improvement

* fix ggml_forward_add functions to work correctly with transposed tensors

uses the same logic as in ggml_compute_forward_add_q_f32, but make it consistent across all ggml_compute_forward_add_... functions.
this also slightly changes the mem access pattern of the different threads to works as in ggml_compute_forward_add_q_f32.

* fix ggml_forward_add1 functions to work correctly with transposed tensors

uses the same logic as in ggml_compute_forward_add1_q_f32, but make it consistent across all ggml_compute_forward_add1_... functions.
this also slightly changes the mem access pattern of the different threads to works as in ggml_compute_forward_add1_q_f32.

* test-grad0.c : add print_elements to help with debugging

* successfully test permute backward

* some minor test-grad0 fixes

* fix sub, mul and div functions to work correctly with transposed tensors

uses the same logic as in add

* implement ggml_cont backward pass

* successfully test transpose backward and permute for all permutations

also test sub, mul and div up to max n_dims

* test-grad0.c add TODO for view_2d and view_3d

add_at (required for view backward pass) is a bit tricky for n_dims > 1.

* fix comments

* successfully test diag_mask_inf and diag_mask_zero backward

* test-grad0 : fix test for div

nargs and ndims was swapped, corrupting the stack

* fix diag_mask to work with non-inplace input

* move dup call into the actual add_at functions

* fix get rows backward pass

* successfully test get_rows backward

* fix view backward pass

add nb parameters to add_at like in view.
together with offset they define how to view dst and src0 during the add_at operation.

* successfully test backward pass of view_1d, view_2d and view_3d

* fix backward pass for rms_norm

I would have used formulas from other frameworks, but they differed so I could not decide which is correct.
Instead it was derived here in comment using manual forward-backward automatic differention of rms_norm and simplification.

* successfully test backward pass of rms_norm

some tests may fail when gradients are large.
could not find a satisfying configuration to check for abs error and relative error that passes all tests while still actually testing the results with tight enough error bounds.
when looking at the values the "failed" tests look actually ok. for example:

rms_norm: ndims=2, i=0, k=2, x0=0.000153, xm=0.000053, xp=0.000253, f0=0.278594, f1=0.086213, g0=961.905457, g1=966.064941, eps=0.000100, error_abs=4.159485, error_rel=0.004324

it is due to the test logic in check_gradients that they fail.

* add todos for llama backward pass

- implementation for ADD1 backward pass should probably use sum instead of mean (but this backward pass is not required)
- repeat is not yet tested and looks like it only works for single element src0 inputs.

* add operation ggml_sum_rows

ggml_sum_rows(shape[a,b,c,d]) -> shape[1,b,c,d]

* add missing GGML_OP_SUM_ROWS

* fix backward pass for repeat

requires ggml_sum_rows

* successfully test backward pass of repeat

* update quantization types in switch-case of add_at and add1

* add baby-llama example training a very small llama model from scratch to output a sinusoidal wave.

had to increase maximum number of optimization parameters to train from scratch.

* fix softmax in baby-llama example

* switching from training with adam to lbfgs produces much better results in the baby-llama example

* train with two examples, creating new tensors each time..

* fix bug when using ggml_opt to optimize params in one context and use a renewable context for eval and opt

when not keeping gradients of model parameters they are overwritten by tensors created by opt, which may be invalid after opt context is renewed.
so we need to keep the original gradients and make dups for opt

* train on multiple examples, generate & print tokens with trained model afterwards

ctx0 for evaluation and optimization is renewed for each sample

* add ggml_reshape_1d, ggml_reshape_4d and ggml_view_4d

* fix soft_max backward pass for input->ne[1] != 1

* add ggml_log operation necessary for cross entropy loss

* add test for ggml_log gradients

* implement backward pass for ggml_sum_rows, necessary for cross entropy loss

* implement ggml_repeat support for rank > 2 tensors

* add test for ggml_sum_rows gradients

* fix training get_example_targets

predict the next token, not the current token!

* add square_error_loss and cross_entropy_loss functions

* optimize loss over multiple samples

this increases computation graph, need parallel batched forward for more efficiency.

* fix backward pass for add_at and change arguments to have same order as in view

* add ggml_set(ctx, a, b) to set b in view of a and return modified a

necessary to set values into kv_self cache and properly propagate the gradients

* fix kv_self gradients for training

use ggml_set instead of ggml_cpy to set kv_self cache with properly propagating gradients

* replace inplace operations for training with copying operations to allow gradient propagation

* add GGML_ASSERT to catch ggml_rope and back value errors

* add trainable lora-only model with all big matrices C split into A,B with A*B=C

this is not a lora-finetune, but the whole model changed to have only low-rank "lora" matrices.

training this instead of the normal model resulted in much worse results though...

* vastly improve training results

instead of logit targets 0 and 1 use -1 and +1.

* shorten code using a variable

* change name of GGML_OP_ADD_AT to GGML_OP_ACC

* smaller default values for baby llama model parameters

* update static assert of GGML_OP_COUNT

* remove shape annotations in llama_eval_internal

* revert disabling of threading for rms_norm and norm

* rename print functions in baby-llama example

* fix call to ggml_set_name

* add missing include for strcmp, etc

* remove trailing whitespace

* reduce number of test-grad0 iterations

avoid exceeding timeout of automated tests

* remove busy loop that was used as sleep for slower sinus wave generation

* disable slow tests grad0 and opt to avoid exceeding timeouts

* c++ in baby-llama example

use c++ includes instead of c includes
use std::min, std::max instead of MIN, MAX macros

* c++ in baby-llama example

use c++ includes instead of c includes
use std::min, std::max instead of MIN, MAX macros

* ggml : fix compiler warnings + cosmetic changes

* ggml : fix nullptr derefs in GGML_OP_CONT and GGML_OP_RESHAPE back

* swap arguments to vDSP_vdiv call

documentation for vDSP_vdiv states: "Note that B comes before A!"

* swap arguments to vDSP_vdiv call

documentation for vDSP_vdiv states: "Note that B comes before A!"

* ggml : swap vDSP_vsub args as per documentation

* add parallel batched forward function for baby-llama training

* cleanup code for batched training

* remove trailing whitespace

* minor : fix compiler warnings + indentation style

* ggml : fix null ptr deref in backward pass

* ggml : remove Q4_2 remnants

* ggml : fix clang-tidy warnings

* baby-llama : couple of clang-tidy warnings

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
xaedes 2 ani în urmă
părinte
comite
f954edda93
9 a modificat fișierele cu 3849 adăugiri și 63 ștergeri
  1. 1 0
      examples/CMakeLists.txt
  2. 4 0
      examples/baby-llama/CMakeLists.txt
  3. 1687 0
      examples/baby-llama/baby-llama.cpp
  4. 617 49
      ggml.c
  5. 193 7
      ggml.h
  6. 9 7
      llama.cpp
  7. 2 0
      tests/CMakeLists.txt
  8. 1131 0
      tests/test-grad0.c
  9. 205 0
      tests/test-opt.c

+ 1 - 0
examples/CMakeLists.txt

@@ -36,4 +36,5 @@ else()
     add_subdirectory(embedding)
     add_subdirectory(save-load-state)
     add_subdirectory(benchmark)
+    add_subdirectory(baby-llama)
 endif()

+ 4 - 0
examples/baby-llama/CMakeLists.txt

@@ -0,0 +1,4 @@
+set(TARGET baby-llama)
+add_executable(${TARGET} baby-llama.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

+ 1687 - 0
examples/baby-llama/baby-llama.cpp

@@ -0,0 +1,1687 @@
+#include "ggml.h"
+#include <vector>
+#include <cassert>
+#include <random>
+#include <cstring>
+
+float frand() {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+struct random_normal_distribution {
+    std::mt19937 gen;
+    std::normal_distribution<float> nd;
+    float min;
+    float max;
+};
+
+void init_random_normal_distribution(struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max) {
+    rnd->gen = std::mt19937(seed);
+    rnd->nd = std::normal_distribution<float>{mean, std};
+    rnd->min = min;
+    rnd->max = max;
+}
+
+float frand_normal(struct random_normal_distribution * rnd) {
+    const float r = rnd->nd(rnd->gen);
+    return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
+}
+
+struct ggml_tensor * randomize_tensor(
+        struct ggml_tensor * tensor,
+        int ndims,
+        const int64_t ne[],
+        float fmin,
+        float fmax) {
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)tensor->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)tensor->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return tensor;
+}
+
+struct ggml_tensor * randomize_tensor_normal(
+        struct ggml_tensor * tensor,
+        int ndims,
+        const int64_t ne[],
+        struct random_normal_distribution * rnd) {
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)tensor->data)[i0] = frand_normal(rnd);
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)tensor->data)[i1*ne[0] + i0] = frand_normal(rnd);
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd);
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd);
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return tensor;
+}
+
+struct llama_hparams {
+    uint32_t n_vocab = 32000;
+    uint32_t n_ctx   = 512;   // this is provided as user input?
+    uint32_t n_embd  = 4096;
+    uint32_t n_mult  = 4;
+    uint32_t n_head  = 32;
+    uint32_t n_layer = 32;
+    uint32_t n_rot   = 64;
+
+    bool operator!=(const llama_hparams & other) const {
+        return memcmp(this, &other, sizeof(llama_hparams));
+    }
+};
+
+uint32_t get_n_ff(const struct llama_hparams* hparams) {
+    const uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult;
+    return n_ff;
+}
+
+struct llama_hparams_lora {
+    uint32_t n_vocab = 32000;
+    uint32_t n_ctx   = 512;   // this is provided as user input?
+    uint32_t n_embd  = 4096;
+    uint32_t n_mult  = 4;
+    uint32_t n_head  = 32;
+    uint32_t n_layer = 32;
+    uint32_t n_rot   = 64;
+    uint32_t n_lora  = 64;
+
+    bool operator!=(const llama_hparams & other) const {
+        return memcmp(this, &other, sizeof(llama_hparams));
+    }
+};
+
+struct llama_layer {
+    // normalization
+    struct ggml_tensor * attention_norm;
+
+    // attention
+    struct ggml_tensor * wq;
+    struct ggml_tensor * wk;
+    struct ggml_tensor * wv;
+    struct ggml_tensor * wo;
+
+    // normalization
+    struct ggml_tensor * ffn_norm;
+
+    // ff
+    struct ggml_tensor * w1;
+    struct ggml_tensor * w2;
+    struct ggml_tensor * w3;
+};
+
+struct llama_layer_lora {
+    // normalization
+    struct ggml_tensor * attention_norm;
+
+    // attention
+    struct ggml_tensor * wqa;
+    struct ggml_tensor * wqb;
+    struct ggml_tensor * wka;
+    struct ggml_tensor * wkb;
+    struct ggml_tensor * wva;
+    struct ggml_tensor * wvb;
+    struct ggml_tensor * woa;
+    struct ggml_tensor * wob;
+
+    // normalization
+    struct ggml_tensor * ffn_norm;
+
+    // ff
+    struct ggml_tensor * w1;
+    struct ggml_tensor * w2;
+    struct ggml_tensor * w3;
+};
+
+
+struct llama_kv_cache {
+    struct ggml_context * ctx = NULL;
+
+    struct ggml_tensor * k;
+    struct ggml_tensor * v;
+
+    // llama_ctx_buffer buf;
+
+    int n; // number of tokens currently in the cache
+};
+
+struct llama_model {
+    struct ggml_context * ctx = NULL;
+
+    llama_hparams hparams;
+
+    struct ggml_tensor * tok_embeddings;
+
+    struct ggml_tensor * norm;
+    struct ggml_tensor * output;
+
+    std::vector<llama_layer> layers;
+};
+
+struct llama_model_lora {
+    struct ggml_context * ctx = NULL;
+
+    llama_hparams_lora hparams;
+
+    struct ggml_tensor * tok_embeddings;
+
+    struct ggml_tensor * norm;
+    struct ggml_tensor * outputa;
+    struct ggml_tensor * outputb;
+
+    std::vector<llama_layer_lora> layers;
+};
+
+void init_model(struct llama_model * model) {
+    const auto & hparams = model->hparams;
+
+    const uint32_t n_embd  = hparams.n_embd;
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_vocab = hparams.n_vocab;
+
+    const uint32_t n_ff = get_n_ff(&hparams);
+
+    struct ggml_context * ctx = model->ctx;
+
+    model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // ("tok_embeddings.weight", {n_embd, n_vocab});
+    model->norm           = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);          // ("norm.weight",           {n_embd});
+    model->output         = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // ("output.weight",         {n_embd, n_vocab});
+
+    model->layers.resize(n_layer);
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        // std::string layers_i = "layers." + std::to_string(i);
+
+        layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".attention_norm.weight", {n_embd});
+
+        layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);     // (layers_i + ".attention.wq.weight", {n_embd, n_embd});
+        layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);     // (layers_i + ".attention.wk.weight", {n_embd, n_embd});
+        layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);     // (layers_i + ".attention.wv.weight", {n_embd, n_embd});
+        layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);     // (layers_i + ".attention.wo.weight", {n_embd, n_embd});
+
+        layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);       // (layers_i + ".ffn_norm.weight", {n_embd});
+
+        layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd,   n_ff);     // (layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff});
+        layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,   n_ff, n_embd);     // (layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd});
+        layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd,   n_ff);     // (layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff});
+    }
+}
+
+
+void init_model_lora(struct llama_model_lora * model) {
+    const auto & hparams = model->hparams;
+
+    const uint32_t n_embd  = hparams.n_embd;
+    const uint32_t n_mult  = hparams.n_mult;
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_vocab = hparams.n_vocab;
+    const uint32_t n_lora  = hparams.n_lora;
+
+    const uint32_t n_ff = ((2*(4*n_embd)/3 + n_mult - 1)/n_mult)*n_mult;
+
+    struct ggml_context * ctx = model->ctx;
+
+    model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // ("tok_embeddings.weight", {n_embd, n_vocab});
+    model->norm           = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);          // ("norm.weight",           {n_embd});
+    model->outputa        = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_vocab); // ("output.weight",         {n_embd, n_vocab});
+    model->outputb        = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd,  n_lora); // ("output.weight",         {n_embd, n_vocab});
+
+    model->layers.resize(n_layer);
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        // std::string layers_i = "layers." + std::to_string(i);
+
+        layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".attention_norm.weight", {n_embd});
+
+        layer.wqa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd);    // (layers_i + ".attention.wq.weight", {n_embd, n_embd});
+        layer.wqb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora);    // (layers_i + ".attention.wq.weight", {n_embd, n_embd});
+        layer.wka = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd);    // (layers_i + ".attention.wk.weight", {n_embd, n_embd});
+        layer.wkb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora);    // (layers_i + ".attention.wk.weight", {n_embd, n_embd});
+        layer.wva = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd);    // (layers_i + ".attention.wv.weight", {n_embd, n_embd});
+        layer.wvb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora);    // (layers_i + ".attention.wv.weight", {n_embd, n_embd});
+        layer.woa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd);    // (layers_i + ".attention.wo.weight", {n_embd, n_embd});
+        layer.wob = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora);    // (layers_i + ".attention.wo.weight", {n_embd, n_embd});
+
+        layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);       // (layers_i + ".ffn_norm.weight", {n_embd});
+
+        layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd,   n_ff);     // (layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff});
+        layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,   n_ff, n_embd);     // (layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd});
+        layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd,   n_ff);     // (layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff});
+    }
+}
+
+void set_param_model(struct llama_model * model) {
+    const auto& hparams = model->hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
+    struct ggml_context* ctx = model->ctx;
+
+    ggml_set_param(ctx, model->tok_embeddings);
+    ggml_set_param(ctx, model->norm);
+    ggml_set_param(ctx, model->output);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        ggml_set_param(ctx, layer.attention_norm);
+        ggml_set_param(ctx, layer.wq);
+        ggml_set_param(ctx, layer.wk);
+        ggml_set_param(ctx, layer.wv);
+        ggml_set_param(ctx, layer.wo);
+        ggml_set_param(ctx, layer.ffn_norm);
+        ggml_set_param(ctx, layer.w1);
+        ggml_set_param(ctx, layer.w2);
+        ggml_set_param(ctx, layer.w3);
+    }
+}
+
+void set_param_model_lora(struct llama_model_lora * model) {
+    const auto& hparams = model->hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
+    struct ggml_context* ctx = model->ctx;
+
+    ggml_set_param(ctx, model->tok_embeddings);
+    ggml_set_param(ctx, model->norm);
+    ggml_set_param(ctx, model->outputa);
+    ggml_set_param(ctx, model->outputb);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        ggml_set_param(ctx, layer.attention_norm);
+        ggml_set_param(ctx, layer.wqa);
+        ggml_set_param(ctx, layer.wqb);
+        ggml_set_param(ctx, layer.wka);
+        ggml_set_param(ctx, layer.wkb);
+        ggml_set_param(ctx, layer.wva);
+        ggml_set_param(ctx, layer.wvb);
+        ggml_set_param(ctx, layer.woa);
+        ggml_set_param(ctx, layer.wob);
+        ggml_set_param(ctx, layer.ffn_norm);
+        ggml_set_param(ctx, layer.w1);
+        ggml_set_param(ctx, layer.w2);
+        ggml_set_param(ctx, layer.w3);
+    }
+}
+
+void randomize_model(struct llama_model * model, int seed, float mean, float std, float min, float max) {
+    const auto & hparams = model->hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
+    struct random_normal_distribution rnd;
+    init_random_normal_distribution(&rnd, seed, mean, std, min, max);
+    randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
+    randomize_tensor_normal(model->norm,           model->norm->n_dims,           model->norm->ne,           &rnd);
+    randomize_tensor_normal(model->output,         model->output->n_dims,         model->output->ne,         &rnd);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = model->layers[i];
+        randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
+
+        randomize_tensor_normal(layer.wq, layer.wq->n_dims, layer.wq->ne, &rnd);
+        randomize_tensor_normal(layer.wk, layer.wk->n_dims, layer.wk->ne, &rnd);
+        randomize_tensor_normal(layer.wv, layer.wv->n_dims, layer.wv->ne, &rnd);
+        randomize_tensor_normal(layer.wo, layer.wo->n_dims, layer.wo->ne, &rnd);
+
+        randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
+
+        randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
+        randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
+        randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+    }
+}
+
+
+void randomize_model_lora(struct llama_model_lora * model, int seed, float mean, float std, float min, float max) {
+    const auto & hparams = model->hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
+    struct random_normal_distribution rnd;
+    init_random_normal_distribution(&rnd, seed, mean, std, min, max);
+    randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
+    randomize_tensor_normal(model->norm,           model->norm->n_dims,           model->norm->ne,           &rnd);
+    randomize_tensor_normal(model->outputa,        model->outputa->n_dims,        model->outputa->ne,         &rnd);
+    randomize_tensor_normal(model->outputb,        model->outputb->n_dims,        model->outputb->ne,         &rnd);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+        auto & layer = model->layers[i];
+        randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
+
+        randomize_tensor_normal(layer.wqa, layer.wqa->n_dims, layer.wqa->ne, &rnd);
+        randomize_tensor_normal(layer.wqb, layer.wqb->n_dims, layer.wqb->ne, &rnd);
+        randomize_tensor_normal(layer.wka, layer.wka->n_dims, layer.wka->ne, &rnd);
+        randomize_tensor_normal(layer.wkb, layer.wkb->n_dims, layer.wkb->ne, &rnd);
+        randomize_tensor_normal(layer.wva, layer.wva->n_dims, layer.wva->ne, &rnd);
+        randomize_tensor_normal(layer.wvb, layer.wvb->n_dims, layer.wvb->ne, &rnd);
+        randomize_tensor_normal(layer.woa, layer.woa->n_dims, layer.woa->ne, &rnd);
+        randomize_tensor_normal(layer.wob, layer.wob->n_dims, layer.wob->ne, &rnd);
+
+        randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
+
+        randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
+        randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
+        randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+    }
+}
+
+bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) {
+    const auto & hparams = model->hparams;
+
+    const uint32_t n_ctx   = hparams.n_ctx;
+    const uint32_t n_embd  = hparams.n_embd;
+    const uint32_t n_layer = hparams.n_layer;
+
+    const int64_t n_mem      = n_layer*n_ctx*n_batch;
+    const int64_t n_elements = n_embd*n_mem;
+
+    // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+
+    // struct ggml_init_params params;
+    // params.mem_size   = cache.buf.size;
+    // params.mem_buffer = cache.buf.addr;
+    // params.no_alloc   = false;
+    if (!cache->ctx) {
+        struct ggml_init_params params;
+        params.mem_size   = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024;
+        params.mem_buffer = NULL;
+        params.no_alloc   = false;
+
+        cache->ctx = ggml_init(params);
+
+        if (!cache->ctx) {
+            fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+            return false;
+        }
+    }
+
+    cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
+    cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
+
+    return true;
+}
+
+bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora * model, int n_batch) {
+    const auto & hparams = model->hparams;
+
+    const uint32_t n_ctx   = hparams.n_ctx;
+    const uint32_t n_embd  = hparams.n_embd;
+    const uint32_t n_layer = hparams.n_layer;
+
+    const int64_t n_mem      = n_layer*n_ctx*n_batch;
+    const int64_t n_elements = n_embd*n_mem;
+
+    // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+
+    // struct ggml_init_params params;
+    // params.mem_size   = cache.buf.size;
+    // params.mem_buffer = cache.buf.addr;
+    // params.no_alloc   = false;
+    if (!cache->ctx) {
+        struct ggml_init_params params;
+        params.mem_size   = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024;
+        params.mem_buffer = NULL;
+        params.no_alloc   = false;
+
+        cache->ctx = ggml_init(params);
+
+        if (!cache->ctx) {
+            fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+            return false;
+        }
+    }
+
+    cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
+    cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
+
+    return true;
+}
+
+struct ggml_tensor * forward(
+        struct llama_model    * model,
+        struct llama_kv_cache * cache,
+        struct ggml_context   * ctx0,
+        struct ggml_cgraph    * gf,
+        struct ggml_tensor    * tokens_input,
+        const  int              n_tokens,
+        const  int              n_past) {
+
+    const int N = n_tokens;
+
+    struct llama_kv_cache& kv_self = *cache;
+    const auto & hparams = model->hparams;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_head  = hparams.n_head;
+    const int n_rot   = hparams.n_rot;
+
+    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens));
+
+    struct ggml_tensor * kc = kv_self.k;
+    struct ggml_tensor * vc = kv_self.v;
+
+    // inpL shape [n_embd,N,1,1]
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        // lctx.use_buf(ctx0, 0);
+
+        // norm
+        {
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_rms_norm(ctx0, inpL);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
+                        cur);
+        }
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            // wq   shape [n_embd, n_embd, 1, 1]
+            // wk   shape [n_embd, n_embd, 1, 1]
+            // Qcur shape [n_embd/n_head, n_head, N, 1]
+            // Kcur shape [n_embd/n_head, n_head, N, 1]
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+
+            // store key and value to memory
+            {
+                // compute the transposed [N, n_embd] V matrix
+                // wv   shape [n_embd, n_embd, 1, 1]
+                // Vcur shape [n_embd, N, 1, 1]
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N)));
+
+                // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
+                // kv_self.v shape [n_embd * n_ctx * n_layer, 1]
+                // k         shape [n_embd * N, 1]   == kv_self.k[:,n_past:n_past+N,il,0]
+                // v         shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0]
+
+                /* {
+                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+                            (   n_ctx)*ggml_element_size(kv_self.v),
+                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+
+                    // important: storing RoPE-ed version of K in the KV cache!
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                } //*/
+
+                kc = ggml_set_1d(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                vc = ggml_set_2d(ctx0, vc, Vcur, (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+            }
+
+            // Qcur shape [n_embd/n_head, n_head, N, 1]
+            // Q shape    [n_embd/n_head, N, n_head, 1]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        Qcur,
+                        0, 2, 1, 3);
+
+            // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
+            // K shape [n_embd/n_head, n_past + N, n_head, 1]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            // K * Q
+            // KQ shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // KQ_masked shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            // KQ_soft_max shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // split cached V into n_head heads
+            //// V shape [n_past + N, n_embd/n_head, n_head, 1]
+            // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1]
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, vc,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(vc),
+                        n_ctx*ggml_element_size(vc)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(vc)*n_embd);
+
+            // KQV shape [n_embd/n_head, N, n_head, 1]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // KQV_merged shape [n_embd/n_head, n_head, N, 1]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            // KQV_merged shape
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N);
+            // cur = ggml_cpy(ctx0,
+            //         KQV_merged,
+            //         ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].wo,
+                    cur);
+        }
+
+        // lctx.use_buf(ctx0, 1);
+
+        // inpFF shape [n_embd,N,1,1]
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                // cur shape [n_embd,N,1,1]
+                cur = ggml_rms_norm(ctx0, inpFF);
+
+                // cur = ffn_norm*cur
+                // cur shape [n_embd,N,1,1]
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
+                        cur);
+            }
+
+            // tmp shape [n_ff,N,1,1]
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model->layers[il].w3,
+                    cur);
+
+            // cur shape [n_ff,N,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w1,
+                    cur);
+
+            // SILU activation
+            // cur shape [n_ff,N,1,1]
+            cur = ggml_silu(ctx0, cur);
+
+            // cur shape [n_ff,N,1,1]
+            cur = ggml_mul(ctx0, cur, tmp);
+
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w2,
+                    cur);
+        }
+
+        // cur shape [n_embd,N,1,1]
+        cur = ggml_add(ctx0, cur, inpFF);
+
+        // input for next layer
+        // inpL shape [n_embd,N,1,1]
+        inpL = cur;
+    }
+
+    // norm
+    {
+
+        // inpL shape [n_embd,N,1,1]
+        inpL = ggml_rms_norm(ctx0, inpL);
+
+        // inpL = norm*inpL
+        // inpL shape [n_embd,N,1,1]
+        inpL = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model->norm, inpL),
+                    inpL);
+
+        //embeddings = inpL;
+    }
+
+    // lm_head
+    // inpL shape [n_vocab,N,1,1]
+    inpL = ggml_mul_mat(ctx0, model->output, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(gf, inpL);
+
+    return inpL;
+}
+
+void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
+    GGML_ASSERT(tensor->n_dims == 1);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+}
+
+void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
+    GGML_ASSERT(tensor->n_dims == 2);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+}
+
+void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
+    GGML_ASSERT(tensor->n_dims == 3);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+    GGML_ASSERT(tensor->ne[2] == ne2);
+}
+
+void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+    GGML_ASSERT(tensor->n_dims == 4);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+    GGML_ASSERT(tensor->ne[2] == ne2);
+    GGML_ASSERT(tensor->ne[3] == ne3);
+}
+
+struct ggml_tensor * forward_batch(
+        struct llama_model    * model,
+        struct llama_kv_cache * cache,
+        struct ggml_context   * ctx0,
+        struct ggml_cgraph    * gf,
+        struct ggml_tensor    * tokens_input,
+        const  int              n_tokens,
+        const  int              n_past,
+        const  int              n_batch) {
+
+    const int N = n_tokens;
+
+    struct llama_kv_cache& kv_self = *cache;
+    const auto & hparams = model->hparams;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_vocab = hparams.n_vocab;
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_head  = hparams.n_head;
+    const int n_rot   = hparams.n_rot;
+    const int n_ff    = get_n_ff(&hparams);
+
+    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
+    memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
+
+    struct ggml_tensor * kc = kv_self.k;
+    struct ggml_tensor * vc = kv_self.v;
+
+    // inpL shape [n_embd,N*n_batch,1]
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
+    assert_shape_2d(inpL, n_embd, N*n_batch);
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        // lctx.use_buf(ctx0, 0);
+
+        // norm
+        {
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_rms_norm(ctx0, inpL);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
+                        cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            // wq   shape [n_embd, n_embd, 1, 1]
+            // wk   shape [n_embd, n_embd, 1, 1]
+            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
+            // Kcur shape [n_embd/n_head, n_head, N, n_batch]
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
+            assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
+
+            // store key and value to memory
+            {
+                // compute the transposed [N, n_embd] V matrix
+                // wv   shape [n_embd, n_embd, 1, 1]
+                // Vcur shape [N, n_embd, n_batch, 1]
+                struct ggml_tensor * Vcur = ggml_cont(ctx0,
+                    ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_mul_mat(ctx0,
+                                model->layers[il].wv,
+                                cur),
+                        n_embd, N, n_batch),
+                        1, 0, 2, 3));
+
+                assert_shape_3d(Vcur, N, n_embd, n_batch);
+
+                // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
+                // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
+                // k         shape [n_embd * N, n_batch]   == kv_self.k[:,n_past:n_past+N,:,il]
+                // v         shape [N, n_embd, n_batch, 1] == kv_self.v[:,n_past:n_past+N,:,il]
+
+                /* {
+                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+                            (   n_ctx)*ggml_element_size(kv_self.v),
+                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+
+                    // important: storing RoPE-ed version of K in the KV cache!
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                } //*/
+
+                kc = ggml_set_2d(ctx0, kc,
+                        ggml_reshape_2d(ctx0, Kcur, n_embd*N, n_batch),
+                        ggml_element_size(kc)*n_embd*n_ctx,
+                        (ggml_element_size(kc)*n_embd)*(il*n_batch*n_ctx + n_past));
+                vc = ggml_set_2d(ctx0, vc,
+                        ggml_reshape_2d(ctx0, Vcur, N*n_embd, n_batch),
+                        ggml_element_size(vc)*n_ctx*n_embd,
+                        ggml_element_size(vc)*(n_past + il*n_embd*n_batch*n_ctx));
+
+                assert_shape_1d(kc, n_embd * n_ctx * n_batch * n_layer);
+                assert_shape_1d(vc, n_embd * n_ctx * n_batch * n_layer);
+            }
+
+            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
+            // Q shape    [n_embd/n_head, N, n_head, n_batch]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        Qcur,
+                        0, 2, 1, 3);
+            assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
+
+            // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
+            // K shape [n_embd/n_head, n_past + N, n_head, n_batch]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_4d(ctx0,
+                            ggml_view_3d(ctx0,
+                                kc,
+                                n_embd,
+                                (n_past + N),
+                                n_batch,
+                                n_embd*ggml_element_size(kc),
+                                n_ctx*n_embd*ggml_element_size(kc),
+                                il*n_batch*n_ctx*n_embd*ggml_element_size(kc)),
+                            n_embd/n_head, n_head, n_past + N, n_batch),
+                        0, 2, 1, 3);
+            assert_shape_4d(K, n_embd/n_head, n_past + N, n_head, n_batch);
+
+            // K * Q
+            // KQ shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            assert_shape_4d(KQ, n_past + N, N, n_head, n_batch);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // KQ_scaled shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+            assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // KQ_masked shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            assert_shape_4d(KQ_masked, n_past + N, N, n_head, n_batch);
+
+            // KQ = soft_max(KQ_masked)
+            // KQ_soft_max shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            assert_shape_4d(KQ_soft_max, n_past + N, N, n_head, n_batch);
+
+            // split cached V into n_head heads
+            // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
+            // V shape [n_past + N, n_embd/n_head, n_head, n_batch] == kv_self.v[:(n_past+N),:,:,il]
+            struct ggml_tensor * V =
+                ggml_view_4d(ctx0, vc,
+                        n_past + N, n_embd/n_head, n_head, n_batch,
+                        ggml_element_size(vc)*n_ctx,
+                        ggml_element_size(vc)*n_ctx*n_embd/n_head,
+                        ggml_element_size(vc)*n_ctx*n_embd,
+                        il*n_batch*n_ctx*n_embd*ggml_element_size(vc));
+            assert_shape_4d(V, n_past + N, n_embd/n_head, n_head, n_batch);
+
+            // KQV shape [n_embd/n_head, N, n_head, n_batch]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
+            // KQV_merged shape
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+            // cur = ggml_cpy(ctx0,
+            //         KQV_merged,
+            //         ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].wo,
+                    cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // lctx.use_buf(ctx0, 1);
+
+        // inpFF shape [n_embd,N*n_batch,1,1]
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+        assert_shape_2d(inpFF, n_embd, N*n_batch);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                // cur shape [n_embd,N*n_batch,1,1]
+                cur = ggml_rms_norm(ctx0, inpFF);
+                assert_shape_2d(cur, n_embd, N*n_batch);
+
+                // cur = ffn_norm*cur
+                // cur shape [n_embd,N*n_batch,1,1]
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
+                        cur);
+                assert_shape_2d(cur, n_embd, N*n_batch);
+            }
+
+            // tmp shape [n_ff,N*n_batch,1,1]
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model->layers[il].w3,
+                    cur);
+            assert_shape_2d(tmp, n_ff, N*n_batch);
+
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w1,
+                    cur);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // SILU activation
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_silu(ctx0, cur);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_mul(ctx0, cur, tmp);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w2,
+                    cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // cur shape [n_embd,N*n_batch,1,1]
+        cur = ggml_add(ctx0, cur, inpFF);
+        assert_shape_2d(cur, n_embd, N*n_batch);
+
+        // input for next layer
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = cur;
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+    }
+
+    // norm
+    {
+
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = ggml_rms_norm(ctx0, inpL);
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+
+        // inpL = norm*inpL
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model->norm, inpL),
+                    inpL);
+
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+
+        //embeddings = inpL;
+    }
+
+    // lm_head
+    // inpL shape [n_vocab,N*n_batch,1,1]
+    inpL = ggml_mul_mat(ctx0, model->output, inpL);
+    assert_shape_2d(inpL, n_vocab, N*n_batch);
+
+    {
+        // inpL shape [n_vocab,N,n_batch,1]
+        inpL = ggml_reshape_3d(ctx0,
+                        inpL,
+                        n_vocab, N, n_batch);
+        assert_shape_3d(inpL, n_vocab, N, n_batch);
+    }
+
+    // run the computation
+    ggml_build_forward_expand(gf, inpL);
+
+    return inpL;
+}
+
+
+struct ggml_tensor * forward_lora(
+        struct llama_model_lora * model,
+        struct llama_kv_cache   * cache,
+        struct ggml_context     * ctx0,
+        struct ggml_cgraph      * gf,
+        struct ggml_tensor      * tokens_input,
+        const  int                n_tokens,
+        const  int                n_past) {
+
+    const int N = n_tokens;
+
+    struct llama_kv_cache& kv_self = *cache;
+    const auto & hparams = model->hparams;
+
+    const int n_ctx   = hparams.n_ctx;
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_head  = hparams.n_head;
+    const int n_rot   = hparams.n_rot;
+
+    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens));
+
+    struct ggml_tensor * kc = kv_self.k;
+    struct ggml_tensor * vc = kv_self.v;
+
+    // inpL shape [n_embd,N,1,1]
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_rms_norm(ctx0, inpL);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
+                        cur);
+        }
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            // wq   shape [n_embd, n_embd, 1, 1]
+            // wk   shape [n_embd, n_embd, 1, 1]
+            // Qcur shape [n_embd/n_head, n_head, N, 1]
+            // Kcur shape [n_embd/n_head, n_head, N, 1]
+            struct ggml_tensor * Qcur = ggml_rope(ctx0,
+                                            ggml_reshape_3d(ctx0,
+                                                ggml_mul_mat(ctx0,
+                                                    model->layers[il].wqa,
+                                                    ggml_mul_mat(ctx0,
+                                                        model->layers[il].wqb,
+                                                        cur)),
+                                                n_embd/n_head, n_head, N),
+                                            n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0,
+                                            ggml_reshape_3d(ctx0,
+                                                ggml_mul_mat(ctx0,
+                                                    model->layers[il].wka,
+                                                    ggml_mul_mat(ctx0,
+                                                        model->layers[il].wkb,
+                                                        cur)),
+                                                n_embd/n_head, n_head, N),
+                                            n_past, n_rot, 0);
+
+            // store key and value to memory
+            {
+                // compute the transposed [N, n_embd] V matrix
+                // wv   shape [n_embd, n_embd, 1, 1]
+                // Vcur shape [n_embd, N, 1, 1]
+                struct ggml_tensor * Vcur = ggml_cont(ctx0,
+                                                ggml_transpose(ctx0,
+                                                    ggml_reshape_2d(ctx0,
+                                                        ggml_mul_mat(ctx0,
+                                                            model->layers[il].wva,
+                                                            ggml_mul_mat(ctx0,
+                                                                model->layers[il].wvb,
+                                                                cur)),
+                                                        n_embd, N)));
+
+                // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
+                // kv_self.v shape [n_embd * n_ctx * n_layer, 1]
+                // k         shape [n_embd * N, 1]   == kv_self.k[:,n_past:n_past+N,il,0]
+                // v         shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0]
+
+                /* {
+                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+                            (   n_ctx)*ggml_element_size(kv_self.v),
+                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+
+                    // important: storing RoPE-ed version of K in the KV cache!
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                } //*/
+
+                kc = ggml_set_1d(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                vc = ggml_set_2d(ctx0, vc, Vcur, (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+            }
+
+            // Qcur shape [n_embd/n_head, n_head, N, 1]
+            // Q shape    [n_embd/n_head, N, n_head, 1]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        Qcur,
+                        0, 2, 1, 3);
+
+            // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
+            // K shape [n_embd/n_head, n_past + N, n_head, 1]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            // K * Q
+            // KQ shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // KQ_masked shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            // KQ_soft_max shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // split cached V into n_head heads
+            //// V shape [n_past + N, n_embd/n_head, n_head, 1]
+            // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1]
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, vc,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(vc),
+                        n_ctx*ggml_element_size(vc)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(vc)*n_embd);
+
+            // KQV shape [n_embd/n_head, N, n_head, 1]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // KQV_merged shape [n_embd/n_head, n_head, N, 1]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            // KQV_merged shape
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N);
+            // cur = ggml_cpy(ctx0,
+            //         KQV_merged,
+            //         ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].woa,
+                    ggml_mul_mat(ctx0,
+                        model->layers[il].wob,
+                        cur));
+        }
+
+        // inpFF shape [n_embd,N,1,1]
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                // cur shape [n_embd,N,1,1]
+                cur = ggml_rms_norm(ctx0, inpFF);
+
+                // cur = ffn_norm*cur
+                // cur shape [n_embd,N,1,1]
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
+                        cur);
+            }
+
+            // tmp shape [n_ff,N,1,1]
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model->layers[il].w3,
+                    cur);
+
+            // cur shape [n_ff,N,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w1,
+                    cur);
+
+            // SILU activation
+            // cur shape [n_ff,N,1,1]
+            cur = ggml_silu(ctx0, cur);
+
+            // cur shape [n_ff,N,1,1]
+            cur = ggml_mul(ctx0, cur, tmp);
+
+            // cur shape [n_embd,N,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w2,
+                    cur);
+        }
+
+        // cur shape [n_embd,N,1,1]
+        cur = ggml_add(ctx0, cur, inpFF);
+
+        // input for next layer
+        // inpL shape [n_embd,N,1,1]
+        inpL = cur;
+    }
+
+    // norm
+    {
+
+        // inpL shape [n_embd,N,1,1]
+        inpL = ggml_rms_norm(ctx0, inpL);
+
+        // inpL = norm*inpL
+        // inpL shape [n_embd,N,1,1]
+        inpL = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model->norm, inpL),
+                    inpL);
+
+        //embeddings = inpL;
+    }
+
+
+    // lm_head
+    // inpL shape [n_vocab,N,1,1]
+    inpL = ggml_mul_mat(ctx0,
+                model->outputa,
+                    ggml_mul_mat(ctx0,
+                        model->outputb,
+                        inpL));
+
+    // ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+    // run the computation
+    ggml_build_forward_expand(gf, inpL);
+
+    return inpL;
+}
+
+void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) {
+    assert(logits->n_dims == 2);
+    assert(probs->n_dims == 2);
+    assert(best_samples->n_dims == 1);
+    assert(logits->ne[1] == best_samples->ne[0]);
+    assert(logits->ne[0] == probs->ne[0]);
+    assert(logits->ne[1] == probs->ne[1]);
+    for (int i = 0; i < logits->ne[1]; ++i) {
+        float max_logit = ggml_get_f32_1d(logits, i * logits->ne[0]);
+        ggml_set_i32_1d(best_samples, i, 0);
+        for (int k = 0; k < logits->ne[0]; ++k) {
+            float logit = ggml_get_f32_1d(logits, i * logits->ne[0] + k);
+            if (logit > max_logit) {
+                max_logit = logit;
+                ggml_set_i32_1d(best_samples, i, k);
+            }
+        }
+        float psum = 0;
+        for (int k = 0; k < logits->ne[0]; ++k) {
+            float logit = ggml_get_f32_1d(logits, i * logits->ne[0] + k);
+            float p = (logit == -INFINITY) ? 0 : expf(logit - max_logit);
+            psum += p;
+            ggml_set_f32_1d(probs, i * probs->ne[0] + k, p);
+        }
+        for (int k = 0; k < logits->ne[0]; ++k) {
+            float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
+            ggml_set_f32_1d(probs, i * probs->ne[0] + k, p / psum);
+        }
+    }
+}
+
+void sample_softmax_batch(struct ggml_context * ctx, struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) {
+    GGML_ASSERT(best_samples->n_dims == 2);
+    GGML_ASSERT(logits->n_dims == 3);
+    GGML_ASSERT(probs->n_dims == 3);
+    int n_tokens = best_samples->ne[0];
+    int n_batch  = best_samples->ne[1];
+    int n_vocab  = logits->ne[0];
+    GGML_ASSERT(n_tokens == logits->ne[1]);
+    GGML_ASSERT(n_batch  == logits->ne[2]);
+    GGML_ASSERT(n_vocab  == probs->ne[0]);
+    GGML_ASSERT(n_tokens == probs->ne[1]);
+    GGML_ASSERT(n_batch  == probs->ne[2]);
+
+    for (int k = 0; k < n_batch; ++k) {
+        struct ggml_tensor * best_samples_k = ggml_view_1d(ctx,
+                                                best_samples,
+                                                best_samples->ne[0],
+                                                k*best_samples->nb[1]);
+        struct ggml_tensor * logits_k       = ggml_view_2d(ctx,
+                                                logits,
+                                                logits->ne[0],
+                                                logits->ne[1],
+                                                logits->nb[1],
+                                                k*logits->nb[2]);
+        struct ggml_tensor * probs_k        = ggml_view_2d(ctx,
+                                                probs,
+                                                probs->ne[0],
+                                                probs->ne[1],
+                                                probs->nb[1],
+                                                k*probs->nb[2]);
+        sample_softmax(logits_k, probs_k, best_samples_k);
+    }
+}
+
+void print_row(struct ggml_tensor * probs, int i) {
+    for (int k = 0; k < probs->ne[0]; ++k) {
+        float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
+        printf(" %.2f", p);
+    }
+    printf("\n");
+}
+
+void print_matrix(struct ggml_tensor * probs) {
+    assert(probs->n_dims == 2);
+    for (int i = 0; i < probs->ne[1]; ++i) {
+        for (int k = 0; k < probs->ne[0]; ++k) {
+            float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
+            printf(" %.2f", p);
+        }
+        printf("\n");
+    }
+}
+
+void print_token(int token, int n_vocab) {
+    for (int k = 0; k < token; ++k) {
+        printf(" ");
+    }
+    printf("X");
+    for (int k = token+1; k < n_vocab; ++k) {
+        printf(" ");
+    }
+    printf("\n");
+}
+
+void print_tokens(struct ggml_tensor * tokens, int n_vocab) {
+    for (int i=0; i<tokens->ne[0]; ++i) {
+        int token = ggml_get_i32_1d(tokens, i);
+        print_token(token, n_vocab);
+    }
+}
+
+void get_example_targets(int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
+    int n_tokens = tokens_input->ne[0];
+    int n_vocab = targets->ne[0];
+    float randomness = 0.0f;
+    // ggml_set_zero(targets);
+    ggml_set_f32(targets, -1.0f);
+    ggml_set_i32_1d(tokens_input, 0, 0);
+    for (int i=1; i<n_tokens+1; ++i) {
+        float x = example_id + i * 3.14159f * 2.0f * 1.0f * 0.5f / n_tokens;
+        float y = sinf(x);//*cosf(x*1.1f+1.0f);
+        float z = (y+1.0f)*0.5f; // scale to [0..1]
+        z += (frand()-0.5f)*(randomness/n_vocab);
+        z = (z < 0.0f) ? 0.0f : (z > 1.0f) ? 1.0f : z; // clamp to [0..1]
+        int token = std::max(1,std::min(1+(int)(z*(float)(n_vocab-1)), n_vocab-1));
+        ggml_set_f32_1d(targets, (i-1)*n_vocab + token, +1.0f);
+        if (i<n_tokens) {
+            ggml_set_i32_1d(tokens_input, i, token);
+        }
+    }
+}
+
+void get_example_targets_batch(struct ggml_context * ctx, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
+    GGML_ASSERT(tokens_input->n_dims == 2);
+    GGML_ASSERT(     targets->n_dims == 3);
+    int n_tokens = tokens_input->ne[0];
+    int n_batch  = tokens_input->ne[1];
+    GGML_ASSERT(n_tokens == targets->ne[1]);
+    GGML_ASSERT(n_batch  == targets->ne[2]);
+
+    for (int k=0; k<n_batch; ++k) {
+        struct ggml_tensor * tokens_input_k = ggml_view_1d(ctx,
+                                                tokens_input,
+                                                tokens_input->ne[0],
+                                                k*tokens_input->nb[1]);
+        struct ggml_tensor * targets_k    = ggml_view_2d(ctx,
+                                                targets,
+                                                targets->ne[0],
+                                                targets->ne[1],
+                                                targets->nb[1],
+                                                k*targets->nb[2]);
+        get_example_targets(example_id*n_batch + k, tokens_input_k, targets_k);
+    }
+}
+
+void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * targets, int n_shift) {
+    int n_tokens = tokens_input->ne[0];
+    int n_vocab = targets->ne[0];
+    for (int i=0; i<n_tokens-n_shift; ++i) {
+        ggml_set_i32_1d(tokens_input, i, ggml_get_i32_1d(tokens_input, i + n_shift));
+        for (int k=0; k<n_vocab; ++k) {
+            ggml_set_f32_1d(targets, i*n_vocab + k, ggml_get_f32_1d(targets, (i + n_shift)*n_vocab + k));
+        }
+    }
+}
+
+struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
+    // todo: instead of a-b: a[1:]-b[:-1]
+    return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, a, b)));
+}
+
+struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
+    const float eps = 1e-3;
+    return
+        ggml_sum(ctx,
+            ggml_neg(ctx,
+                ggml_sum_rows(ctx,
+                    ggml_mul(ctx,
+                        ggml_soft_max(ctx, a),
+                        ggml_log(ctx,
+                            ggml_add1(ctx,
+                                ggml_soft_max(ctx, b),
+                                ggml_new_f32(ctx, eps)))))));
+}
+
+int main(int argc, char ** argv) {
+    if (argc < 1) {
+        fprintf(stderr, "usage: %s\n", argv[0]);
+
+        return 1;
+    }
+
+    struct ggml_init_params lcparams;
+    lcparams.mem_size   = 1024ll*1024ll*1024ll;
+    lcparams.mem_buffer = NULL;
+    lcparams.no_alloc   = false;
+
+    struct llama_model model;
+    model.hparams.n_vocab = 8;
+    model.hparams.n_ctx   = 8;
+    model.hparams.n_embd  = 32;
+    model.hparams.n_mult  = 2;
+    model.hparams.n_head  = 8;
+    model.hparams.n_layer = 1;
+    model.hparams.n_rot   = std::min(16u, model.hparams.n_embd / model.hparams.n_head);
+
+    // model.hparams.n_embd  = 32;
+    // model.hparams.n_mult  = 2;
+    // model.hparams.n_head  = 4;
+    // model.hparams.n_layer = 8;
+    // model.hparams.n_rot   = 8;
+
+    model.ctx = ggml_init(lcparams);
+    printf("init model\n");
+    init_model(&model);
+    set_param_model(&model);
+
+    randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
+
+/*
+    struct llama_model_lora model_lora;
+    // model.hparams.n_vocab = 6;
+    // model.hparams.n_ctx   = 64;
+    // model.hparams.n_embd  = 128;
+    // model.hparams.n_mult  = 2;
+    // model.hparams.n_head  = 8;
+    // model.hparams.n_layer = 6;
+    // model.hparams.n_rot   = model.hparams.n_embd / model.hparams.n_head;
+
+    model_lora.hparams.n_vocab = 16;
+    model_lora.hparams.n_ctx   = 32;
+    model_lora.hparams.n_embd  = 256;
+    model_lora.hparams.n_mult  = 2;
+    model_lora.hparams.n_head  = 16;
+    model_lora.hparams.n_layer = 1;
+    model_lora.hparams.n_lora  = 64;
+    model_lora.hparams.n_rot   = MIN(16, model_lora.hparams.n_embd / model_lora.hparams.n_head);
+    // model.hparams.n_rot   = (model.hparams.n_embd / model.hparams.n_head) / 2;
+
+    // model.hparams.n_embd  = 32;
+    // model.hparams.n_mult  = 2;
+    // model.hparams.n_head  = 4;
+    // model.hparams.n_layer = 8;
+    // model.hparams.n_rot   = 8;
+
+    model_lora.ctx = ggml_init(lcparams);
+    printf("init model_lora\n");
+    init_model_lora(&model_lora);
+    set_param_model_lora(&model_lora);
+
+    randomize_model_lora(&model_lora, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
+*/
+    int n_batch = 8;
+    // key + value cache for the self attention
+    struct llama_kv_cache kv_self;
+    printf("init_kv_cache\n");
+    kv_self.ctx = model.ctx;
+    init_kv_cache(&kv_self, &model, n_batch);
+    //init_kv_cache_lora(&kv_self, &model_lora);
+
+    size_t    compute_size = 1024ll*1024ll*1024ll;
+    uint8_t * compute_addr = new uint8_t[compute_size];
+
+    int n_examples = 256;
+    int n_tokens = model.hparams.n_ctx;
+    int n_vocab  = model.hparams.n_vocab;
+
+    for (int ex=0; ex<n_examples; ++ex) {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ compute_size,
+            /*.mem_buffer =*/ compute_addr,
+            /*.no_alloc   =*/ false,
+        };
+
+        struct ggml_context * ctx0 = ggml_init(params);
+
+        struct ggml_tensor * after_opt_best_samples  = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
+        struct ggml_tensor * after_opt_probs         = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+        struct ggml_tensor * tokens_input            = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
+        struct ggml_tensor * targets                 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+
+        int n_past = 0;
+
+        ggml_cgraph gf = {};
+        gf.n_threads = 1;
+
+        get_example_targets_batch(ctx0, 64*ex+0,  tokens_input, targets);
+
+        struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
+        // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits);
+        struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
+
+        ggml_build_forward_expand(&gf, e);
+        ggml_graph_compute(ctx0, &gf);
+
+        float error_before_opt = ggml_get_f32_1d(e, 0);
+
+        struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
+        struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
+        opt_params_adam.print_forward_graph = false;
+        opt_params_adam.print_backward_graph = false;
+        opt_params_lbfgs.print_forward_graph = false;
+        opt_params_lbfgs.print_backward_graph = false;
+        opt_params_adam.adam.n_iter = 16;
+        opt_params_lbfgs.lbfgs.n_iter = 16;
+        // ggml_opt(ctx0, opt_params_adam, e);
+        ggml_opt(ctx0, opt_params_lbfgs, e);
+        //
+        ggml_build_forward_expand(&gf, e);
+        ggml_graph_compute(ctx0, &gf);
+
+        float error_after_opt = ggml_get_f32_1d(e, 0);
+
+        if (ex % 8 == 0) {
+            printf("Example %d\n", (ex+1));
+            printf("error_before_opt: %.2f\n", error_before_opt);
+            printf("error_after_opt:  %.2f\n", error_after_opt);
+        }
+
+        if (ex % 64 == 0) {
+            sample_softmax_batch(ctx0, logits, after_opt_probs, after_opt_best_samples);
+            // printf("probabilities after optimization:\n");
+            // print_matrix(after_opt_probs);
+            printf("best samples after optimization:\n");
+            print_tokens(after_opt_best_samples, n_vocab);
+        }
+
+        ggml_free(ctx0);
+    }
+
+    {
+        int n_gen = 128;
+        int sample_ctx = n_tokens-n_tokens/8;
+
+        printf("Generating %d tokens.\n", n_gen);
+
+        struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * targets      = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
+
+        get_example_targets(137, tokens_input, targets);
+        for (int i=sample_ctx; i<n_tokens; ++i) {
+            ggml_set_i32_1d(tokens_input, i, n_vocab/2);
+        }
+
+        for (int i=0; i<sample_ctx-1; ++i) {
+            print_token(ggml_get_i32_1d(tokens_input, i), n_vocab);
+        }
+        printf("---\n");
+        for (int i=0; i<n_gen; ++i) {
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ compute_size,
+                /*.mem_buffer =*/ compute_addr,
+                /*.no_alloc   =*/ false,
+            };
+            struct ggml_context * ctx0 = ggml_init(params);
+
+            ggml_cgraph gf = {};
+            gf.n_threads = 1;
+
+            int n_past = 0;
+            struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
+
+            ggml_build_forward_expand(&gf, logits);
+            ggml_graph_compute(ctx0, &gf);
+
+            struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
+            struct ggml_tensor * probs        = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
+
+            sample_softmax(logits, probs, best_samples);
+
+            // int sample_at = n_tokens-1;
+            int token = ggml_get_i32_1d(best_samples, sample_ctx-1);
+
+            // print_row(probs, sample_at);
+            print_token(token, n_vocab);
+
+            lshift_examples(tokens_input, targets, 1);
+            ggml_set_i32_1d(tokens_input, 0, 0);
+            ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
+
+            ggml_free(ctx0);
+        }
+    }
+
+    print_matrix(model.tok_embeddings);
+
+    printf("done\n");
+    // ggml_free(kv_self.ctx);
+    // ggml_free(model_lora.ctx);
+    ggml_free(model.ctx);
+    return 0;
+}

Fișier diff suprimat deoarece este prea mare
+ 617 - 49
ggml.c


+ 193 - 7
ggml.h

@@ -192,7 +192,7 @@
 
 #define GGML_MAX_DIMS          4
 #define GGML_MAX_NODES         4096
-#define GGML_MAX_PARAMS        16
+#define GGML_MAX_PARAMS        256
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_OPT           4
 #define GGML_DEFAULT_N_THREADS 4
@@ -262,12 +262,16 @@ extern "C" {
 
         GGML_OP_DUP,
         GGML_OP_ADD,
+        GGML_OP_ADD1,
+        GGML_OP_ACC,
         GGML_OP_SUB,
         GGML_OP_MUL,
         GGML_OP_DIV,
         GGML_OP_SQR,
         GGML_OP_SQRT,
+        GGML_OP_LOG,
         GGML_OP_SUM,
+        GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_REPEAT,
         GGML_OP_ABS,
@@ -277,12 +281,15 @@ extern "C" {
         GGML_OP_RELU,
         GGML_OP_GELU,
         GGML_OP_SILU,
+        GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
+        GGML_OP_RMS_NORM_BACK,
 
         GGML_OP_MUL_MAT,
 
         GGML_OP_SCALE,
+        GGML_OP_SET,
         GGML_OP_CPY,
         GGML_OP_CONT,
         GGML_OP_RESHAPE,
@@ -290,9 +297,13 @@ extern "C" {
         GGML_OP_PERMUTE,
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
+        GGML_OP_GET_ROWS_BACK,
+        GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
+        GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
         GGML_OP_ROPE,
+        GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CONV_1D_1S,
         GGML_OP_CONV_1D_2S,
@@ -496,6 +507,29 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_add1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_acc(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_acc_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
     GGML_API struct ggml_tensor * ggml_sub(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -519,12 +553,24 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_log(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_log_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return scalar
-    // TODO: compute sum along rows
     GGML_API struct ggml_tensor * ggml_sum(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]
+    GGML_API struct ggml_tensor * ggml_sum_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // mean along rows
     GGML_API struct ggml_tensor * ggml_mean(
             struct ggml_context * ctx,
@@ -566,6 +612,13 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_silu_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
@@ -576,6 +629,13 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_rms_norm_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // A: m rows, n columns
     // B: p rows, n columns (i.e. we transpose it internally)
     // result is m columns, p rows
@@ -588,12 +648,66 @@ extern "C" {
     // operations on tensors without backpropagation
     //
 
-    // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_scale(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_scale_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // b -> view(a,offset,nb1,nb2,3), return modified a
+    GGML_API struct ggml_tensor * ggml_set(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1,nb2,3), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_set_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_set_1d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1,nb2,3), return modified a
+    GGML_API struct ggml_tensor * ggml_set_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1,nb2,3), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_2d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset);
+
+
     // a -> b, return view(b)
     GGML_API struct ggml_tensor * ggml_cpy(
             struct ggml_context * ctx,
@@ -614,6 +728,11 @@ extern "C" {
 
     // return view(a)
     // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
     GGML_API struct ggml_tensor * ggml_reshape_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -629,6 +748,14 @@ extern "C" {
             int64_t               ne1,
             int64_t               ne2);
 
+    GGML_API struct ggml_tensor * ggml_reshape_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
     // offset in bytes
     GGML_API struct ggml_tensor * ggml_view_1d(
             struct ggml_context * ctx,
@@ -654,6 +781,18 @@ extern "C" {
             size_t                nb2, // slice stride in bytes
             size_t                offset);
 
+    GGML_API struct ggml_tensor * ggml_view_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                nb3,
+            size_t                offset);
+
     GGML_API struct ggml_tensor * ggml_permute(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -672,20 +811,50 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_get_rows_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c);
+
+    GGML_API struct ggml_tensor * ggml_diag(
+        struct ggml_context     * ctx,
+        struct ggml_tensor      * a);
+
     // set elements above the diagonal to -INF
-    // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_diag_mask_inf(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past);
 
     // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // set elements above the diagonal to 0
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * gml_diag_mask_zero_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
     GGML_API struct ggml_tensor * ggml_soft_max(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // rotary position embedding
     // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
     // TODO: avoid creating a new tensor every time
@@ -696,6 +865,23 @@ extern "C" {
             int                   n_dims,
             int                   mode);
 
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode);
+
+    // rotary position embedding backward, i.e compute dx from dy
+    // a - dy
+    GGML_API struct ggml_tensor * ggml_rope_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode);
+
     // alibi position embedding
     // in-place, returns view(a)
     struct ggml_tensor * ggml_alibi(
@@ -740,13 +926,13 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context        * ctx,
             struct ggml_tensor         * a,
-            const  ggml_unary_op_f32_t fun);
+                   ggml_unary_op_f32_t   fun);
 
     GGML_API struct ggml_tensor * ggml_map_binary_f32(
             struct ggml_context         * ctx,
             struct ggml_tensor          * a,
             struct ggml_tensor          * b,
-            const  ggml_binary_op_f32_t fun);
+                   ggml_binary_op_f32_t   fun);
 
     //
     // automatic differentiation

+ 9 - 7
llama.cpp

@@ -1128,8 +1128,8 @@ static bool llama_eval_internal(
         // self-attention
         {
             // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");
 
@@ -1170,17 +1170,19 @@ static bool llama_eval_internal(
             struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
             ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
 
-            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
             // KQ_masked = mask_past(KQ_scaled)
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
             ggml_set_name(KQ_masked, "KQ_masked");
 
             // KQ = soft_max(KQ_masked)
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
+
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
@@ -1281,7 +1283,7 @@ static bool llama_eval_internal(
     lctx.use_buf(ctx0, -1);
 
     // logits -> probs
-    //inpL = ggml_soft_max(ctx0, inpL);
+    //inpL = ggml_soft_max_inplace(ctx0, inpL);
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
@@ -2375,7 +2377,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
 
             if (scaling != 1.0f) {
                 ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
-                BA = ggml_scale(lora_ctx, BA, scale_tensor);
+                BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
             }
 
             ggml_tensor * r;

+ 2 - 0
tests/CMakeLists.txt

@@ -10,3 +10,5 @@ llama_add_test(test-quantize-fns.cpp)
 llama_add_test(test-quantize-perf.cpp)
 llama_add_test(test-sampling.cpp)
 llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
+# llama_add_test(test-grad0.c) # SLOW
+# llama_add_test(test-opt.c) # SLOW

+ 1131 - 0
tests/test-grad0.c

@@ -0,0 +1,1131 @@
+#include "ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+#define MAX_NARGS 2
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define GGML_SILU_FP16
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+float frand(void) {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+int irand(int n) {
+    if (n == 0) return 0;
+    else return rand()%n;
+}
+
+void get_random_dims(int64_t * dims, int ndims) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = 1 + irand(4);
+    }
+}
+
+struct ggml_tensor * get_random_tensor(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+struct ggml_tensor * get_random_tensor_int(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        int32_t imin,
+        int32_t imax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((int32_t *)result->data)[i0] = irand(imax - imin) + imin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+float get_element(const struct ggml_tensor * t, int idx) {
+    if (t->type == GGML_TYPE_F32) {
+        return ((float *)t->data)[idx];
+    } else if (t->type == GGML_TYPE_I32) {
+        return ((int32_t *)t->data)[idx];
+    } else {
+        assert(false);
+        return INFINITY;
+    }
+}
+
+void set_element(struct ggml_tensor * t, int idx, float value) {
+    ((float *)t->data)[idx] = value;
+}
+
+void print_elements(const char* label, const struct ggml_tensor * t) {
+    if (!t) {
+        printf("%s: %s = null\n", __func__, label);
+        return;
+    }
+    const int nelements = ggml_nelements(t);
+    printf("%s: %s = [", __func__, label);
+    for (int k = 0; k < nelements; ++k) {
+        if (k > 0) { printf(", "); }
+        printf("%.5f", get_element(t, k));
+    }
+    printf("] shape: [");
+    for (int k = 0; k < t->n_dims; ++k) {
+        if (k > 0) { printf(", "); }
+        printf("%d", (int)t->ne[k]);
+    }
+    printf("]\n");
+
+}
+
+bool check_gradient(
+        const char * op_name,
+        struct ggml_context * ctx0,
+        struct ggml_tensor * x[],
+        struct ggml_tensor * f,
+        int ndims,
+        int nargs,
+        float eps,
+        float max_error_abs,
+        float max_error_rel) {
+
+    struct ggml_cgraph gf = ggml_build_forward (f);
+    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_reset  (&gf);
+    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_compute(ctx0, &gb);
+
+    // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
+    // ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
+
+    for (int i = 0; i < nargs; ++i) {
+        const int nelements = ggml_nelements(x[i]);
+        for (int k = 0; k < nelements; ++k) {
+            // compute gradient using finite differences
+            const float x0 = get_element(x[i], k);
+            const float xm = x0 - eps;
+            const float xp = x0 + eps;
+            set_element(x[i], k, xp);
+            ggml_graph_compute(ctx0, &gf);
+
+            const float f0 = ggml_get_f32_1d(f, 0);
+
+            set_element(x[i], k, xm);
+            ggml_graph_compute(ctx0, &gf);
+
+            const float f1 = ggml_get_f32_1d(f, 0);
+
+            const float g0 = (f0 - f1)/(2.0f*eps);
+
+            set_element(x[i], k, x0);
+
+            // compute gradient using backward graph
+            ggml_graph_reset  (&gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(ctx0, &gb);
+
+            const float g1 = get_element(x[i]->grad, k);
+
+            const float error_abs = fabsf(g0 - g1);
+            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
+
+            if (error_abs > max_error_abs || error_rel > max_error_rel) {
+                printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
+                            op_name, ndims, i, k, x0, xm, xp, f0, f1, g0, g1, eps, error_abs, error_rel);
+                //assert(false);
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+// TODO: clean-up this ..
+bool check_mat_mul(
+        const struct ggml_tensor * y,
+        const struct ggml_tensor * x0,
+        const struct ggml_tensor * x1) {
+    float * dst  = (float *) y->data;
+    float * src0 = (float *) x0->data;
+    float * src1 = (float *) x1->data;
+
+    const int nc = x0->ne[1];
+    const int nr = x1->ne[1];
+    const int nk = x0->ne[0];
+
+    GGML_PRINT_DEBUG("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk);
+
+    GGML_PRINT_DEBUG("x0:\n");
+    for (int j = 0; j < x0->ne[1]; ++j) {
+        for (int i = 0; i < x0->ne[0]; ++i) {
+            GGML_PRINT_DEBUG("%6.3f ", src0[j*nk + i]);
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+    GGML_PRINT_DEBUG("\n");
+
+    GGML_PRINT_DEBUG("x1:\n");
+    for (int j = 0; j < x1->ne[1]; ++j) {
+        for (int i = 0; i < x1->ne[0]; ++i) {
+            GGML_PRINT_DEBUG("%6.3f ", src1[j*nk + i]);
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+    GGML_PRINT_DEBUG("\n");
+
+    GGML_PRINT_DEBUG("y: n_dims = %d, (%lld, %lld)\n", y->n_dims, y->ne[0], y->ne[1]);
+    for (int j = 0; j < y->ne[1]; ++j) {
+        for (int i = 0; i < y->ne[0]; ++i) {
+            GGML_PRINT_DEBUG("%6.3f ", dst[j*nr + i]);
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+
+    for (int i = 0; i < nr; ++i) {
+        for (int j = 0; j < nc; ++j) {
+            float sum = 0.0f;
+
+            for (int k = 0; k < nk; ++k) {
+                sum += src0[j*nk + k]*src1[i*nk + k];
+            }
+
+            if (fabsf(dst[i*nc + j] - sum) > 1e-5f) {
+                fprintf(stderr, "check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum);
+                assert(false);
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+#define NUM_PERMUTATIONS (4*3*2*1)
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 128*1024*1024,
+        .mem_buffer = NULL,
+        .no_alloc   = false,
+    };
+
+    int64_t ne[4];
+
+    int all_permutations[4 * NUM_PERMUTATIONS];
+    {
+        int count = 0;
+        for (int ax0=0; ax0<4; ++ax0) {
+            for (int ax1=0; ax1<4; ++ax1) {
+                if (ax1 == ax0) continue;
+                for (int ax2=0; ax2<4; ++ax2) {
+                    if (ax2 == ax0) continue;
+                    if (ax2 == ax1) continue;
+                    for (int ax3=0; ax3<4; ++ax3) {
+                        if (ax3 == ax0) continue;
+                        if (ax3 == ax1) continue;
+                        if (ax3 == ax2) continue;
+                        assert(count < NUM_PERMUTATIONS);
+                        all_permutations[count*4+0] = ax0;
+                        all_permutations[count*4+1] = ax1;
+                        all_permutations[count*4+2] = ax2;
+                        all_permutations[count*4+3] = ax3;
+                        ++count;
+                    }
+                }
+            }
+        }
+    }
+
+
+    // original loop: 1000
+    int niter = 4;
+    const char *env = getenv("GGML_NLOOP");
+    if (env != NULL) {
+        niter = atoi(env);
+    }
+    if (argc > 1) {
+        niter = atoi(argv[1]);
+    }
+    for (int iter = 0; iter < niter; ++iter) {
+        printf("test-grad0: iter:%d/%d\n", iter, niter);
+        struct ggml_context * ctx0 = ggml_init(params);
+
+        get_random_dims(ne, 4);
+
+        struct ggml_tensor * x[MAX_NARGS];
+
+        // add
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
+
+                check_gradient("add", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+            }
+        }
+
+        // sub
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
+
+                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // mul
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
+
+                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // div
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, 0.5f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
+
+                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
+            }
+        }
+
+        // sqr
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
+
+                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // sqrt
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
+
+                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+            }
+        }
+
+        // log
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
+
+                check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+            }
+        }
+
+        // sum
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
+
+                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+
+        // sum_rows
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
+
+                check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            }
+        }
+
+        // repeat
+        {
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+
+            ne2[0] = ne[0] * ne2[0];
+            ne2[1] = ne[1] * ne2[1];
+            ne2[2] = 1;
+            ne2[3] = 1;
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
+
+                check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            }
+
+        }
+
+        // abs (finite differences do not work)
+        //{
+        //    const int nargs = 1;
+
+        //    for (int ndims = 1; ndims <= 2; ++ndims) {
+        //        for (int i = 0; i < nargs; ++i) {
+        //            x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+        //            ggml_set_param(ctx0, x[i]);
+        //        }
+
+        //        struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
+
+        //        check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f);
+        //    }
+        //}
+
+        // mul_mat
+        {
+            const int nargs = 2;
+
+            for (int ndims = 2; ndims <= 2; ++ndims) {
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                {
+                    int64_t ne2[4];
+                    get_random_dims(ne2, 4);
+                    ne2[0] = ne[0];
+                    x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                }
+
+                ggml_set_param(ctx0, x[0]);
+                ggml_set_param(ctx0, x[1]);
+
+                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+                struct ggml_tensor * f = ggml_sum(ctx0, m);
+
+                GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
+
+                check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_mat_mul(m, x[1], x[0]);
+            }
+        }
+
+        // silu
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0]));
+
+#ifdef GGML_SILU_FP16
+                // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
+                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
+#else
+                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+#endif
+            }
+        }
+
+        // rms_norm
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
+
+                check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
+            }
+        }
+
+        // scale
+        {
+            const int nargs = 2;
+
+            int64_t ne2[4];
+            ne2[0] = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+                ggml_set_param(ctx0, x[1]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], x[1]));
+
+                check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // cpy
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
+
+                check_gradient("cpy", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // reshape (1d->nd)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                int64_t ne2[4];
+                ne2[0] = 1;
+                ne2[1] = 1;
+                ne2[2] = 1;
+                ne2[3] = 1;
+                for (int i = 0; i < ndims; ++i) {
+                    ne2[0] *= ne[i];
+                }
+                x[0] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
+                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // reshape (nd->1d)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                int64_t ne2[4];
+                ne2[0] = 1;
+                ne2[1] = 1;
+                ne2[2] = 1;
+                ne2[3] = 1;
+                for (int i = 0; i < ndims; ++i) {
+                    ne2[0] *= ne[i];
+                }
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
+                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 1d
+        {
+            int64_t ne2[4] = { 1, 1, 1, 1 };
+
+            const int nargs = 2;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 1);
+                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 1);
+                }
+
+                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
+                const int offset = irand(max_offset) * ggml_element_size(x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 2d
+        {
+            int64_t ne2[4]         = { 1, 1, 1, 1 };
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 2;
+            for (int ndims = 2; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 2);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 2);
+                }
+
+                x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                const int offset = offsets[0] + offsets[1];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 3d
+        {
+            int64_t ne2[4]         = { 1, 1, 1, 1 };
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 2;
+            for (int ndims = 3; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 3);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 3);
+                }
+
+                x[1] = get_random_tensor(ctx0, 3, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
+                const int offset = offsets[0] + offsets[1] + offsets[2];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 4d
+        {
+            int64_t ne2[4]         = { 1, 1, 1, 1 };
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 2;
+            for (int ndims = 4; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 4);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[3] > ne[3]) || (ne2[0]*ne2[1]*ne2[2]*ne2[3] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 4);
+                }
+
+                x[1] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
+                max_offsets[3] = MAX(0, x[0]->ne[3] - x[1]->ne[3]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
+                offsets[3] = irand(max_offsets[3]) * x[0]->nb[3];
+                const int offset = offsets[0] + offsets[1] + offsets[2] + offsets[3];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // set_1d
+        {
+            int64_t ne2[4];
+
+            const int nargs = 2;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 1);
+                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 1);
+                }
+
+                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
+                const int offset = irand(max_offset) * ggml_element_size(x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
+
+                check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // set_2d
+        {
+            int64_t ne2[4];
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 1;
+            for (int ndims = 2; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 2);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 2);
+                }
+
+                x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                const int offset = offsets[0] + offsets[1];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
+
+                check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // view_1d
+        {
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int k0 = irand(ggml_nelements(x[0]));
+                const int k1 = irand(ggml_nelements(x[0]));
+                const int i0 = MIN(k0, k1);
+                const int i1 = MAX(k0, k1);
+
+                const int offset = i0 * sizeof(float);
+                const int nelem  = i1 - i0;
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
+
+                check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // view_2d
+        {
+            int64_t ne2[4];
+            int64_t nb2[4];
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                get_random_dims(ne2, 2);
+                while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
+                    get_random_dims(ne2, 2);
+                }
+                const int count = ne2[0]*ne2[1];
+
+                nb2[0] = sizeof(float);
+                nb2[1] = nb2[0]*ne2[0];
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int max_offset = ggml_nelements(x[0]) - count;
+                const int offset = irand(max_offset+1) * sizeof(float);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
+
+                check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // view_3d
+        {
+            int64_t ne2[4] = {1,1,1,1};
+            int64_t nb2[4] = {0,0,0,0};
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                get_random_dims(ne2, 3);
+                while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
+                    get_random_dims(ne2, 3);
+                }
+                const int count = ne2[0]*ne2[1]*ne2[2];
+
+                nb2[0] = sizeof(float);
+                nb2[1] = nb2[0]*ne2[0];
+                nb2[2] = nb2[1]*ne2[1];
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int max_offset = ggml_nelements(x[0]) - count;
+                const int offset = irand(max_offset+1) * sizeof(float);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
+
+                check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // permute
+        {
+            int64_t ne2[4];
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims)
+            {
+                // ggml_permute will set axes of dimensions below n_dims to 1.
+                // to make ggml_permute work correctly on all axes,
+                // the input tensor needs maximal n_dim of 4.
+                for (int i=0; i<ndims; ++i) {
+                    ne2[i] = ne[i];
+                }
+                for (int i=ndims; i<4; ++i) {
+                    ne2[i] = 1;
+                }
+                x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int p = irand(NUM_PERMUTATIONS);
+                const int ax0 = all_permutations[p*4+0];
+                const int ax1 = all_permutations[p*4+1];
+                const int ax2 = all_permutations[p*4+2];
+                const int ax3 = all_permutations[p*4+3];
+
+                // sum requires contiguous tensor rows
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3)));
+
+                check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // transpose
+        {
+            int64_t ne2[4];
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims)
+            {
+                // ggml_transpose will set axes of dimensions below n_dims to 1.
+                // to make ggml_transpose work correctly on all axes,
+                // the input tensor needs maximal n_dim of 4.
+                for (int i=0; i<ndims; ++i) {
+                    ne2[i] = ne[i];
+                }
+                for (int i=ndims; i<4; ++i) {
+                    ne2[i] = 1;
+                }
+                x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                // sum requires contiguous tensor rows
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0])));
+
+                check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // get_rows
+        {
+            int64_t ne2[4] = {ne[0], ne[1], 1, 1};
+            int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
+            const int nargs = 1;
+            const int ndims = 2;
+            x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+            x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]);
+
+            ggml_set_param(ctx0, x[0]);
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
+
+            check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        }
+
+        // diag_mask_inf
+        {
+            const int nargs = 1;
+            const int ndims = 2;
+
+            x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+            ggml_set_param(ctx0, x[0]);
+
+            int n_past = irand(ne[0]);
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
+
+            check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        }
+
+        // diag_mask_zero
+        {
+            const int nargs = 1;
+            const int ndims = 2;
+
+            x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+            ggml_set_param(ctx0, x[0]);
+
+            int n_past = irand(ne[0]);
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
+
+            check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        }
+
+        // softmax
+        {
+            const int nargs = 1;
+
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+
+            for (int ndims = 1; ndims <= 3; ++ndims) {
+                x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
+
+                check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // rope
+        {
+            const int nargs = 1;
+
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+            ne2[0] += ne2[0] % 2;
+            int n_rot = ne2[0];
+
+            for (int ndims = 3; ndims <= 4; ++ndims) {
+                for (int mode = 0; mode < 4; ++mode) {
+                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
+                        x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+
+                        ggml_set_param(ctx0, x[0]);
+
+                        const bool skip_past = (mode & 1);
+                        if (skip_past) {
+                            // we have no past, so this would have to work on uninitialized memory.
+                            // we only test the gradients here;
+                            // skip_past should have no influence on gradient computation.
+                            // so when other modes work, we assume that this does as well.
+                            continue;
+                        }
+
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode));
+
+                        GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
+                        check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+                    }
+                }
+            }
+        }
+
+        ggml_free(ctx0);
+    }
+
+    return 0;
+}

+ 205 - 0
tests/test-opt.c

@@ -0,0 +1,205 @@
+#include "ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+#define MAX_NARGS 2
+
+
+//
+// logging
+//
+#define GGML_DEBUG 0
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+
+float frand() {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+int irand(int n) {
+    return rand()%n;
+}
+
+void get_random_dims(int64_t * dims, int ndims) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = 1 + irand(4);
+    }
+}
+
+void get_random_dims_minmax(int64_t * dims, int ndims, int min, int max) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = min + irand(max-min);
+    }
+}
+
+
+struct ggml_tensor * get_random_tensor(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+float get_element(const struct ggml_tensor * t, int idx) {
+    return ((float *)t->data)[idx];
+}
+
+void set_element(struct ggml_tensor * t, int idx, float value) {
+    ((float *)t->data)[idx] = value;
+}
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        .mem_size   = 1024*1024*1024,
+        .mem_buffer = NULL,
+        .no_alloc   = false,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+
+    int64_t ne1[4] = {4, 1024, 1, 1};
+    int64_t ne2[4] = {4, 2048, 1, 1};;
+    int64_t ne3[4] = {1024, 2048, 1, 1};
+
+    struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
+    struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
+    ggml_set_param(ctx, a);
+    ggml_set_param(ctx, b);
+
+    struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1);
+
+    struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b);
+    struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
+    struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
+
+
+    struct ggml_cgraph ge = ggml_build_forward(e);
+    ggml_graph_reset  (&ge);
+    ggml_graph_compute(ctx, &ge);
+    const float fe = ggml_get_f32_1d(e, 0);
+    printf("%s: e = %.4f\n", __func__, fe);
+
+    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
+
+    ggml_opt(ctx, opt_params, e);
+
+    ggml_graph_reset  (&ge);
+    ggml_graph_compute(ctx, &ge);
+    const float fe_opt = ggml_get_f32_1d(e, 0);
+    printf("%s: original  e = %.4f\n", __func__, fe);
+    printf("%s: optimized e = %.4f\n", __func__, fe_opt);
+
+    const bool success = (fe_opt <= fe);
+    assert(success);
+
+    ggml_free(ctx);
+    return success ? 0 : -1;
+}
+// int64_t ne1[4] = {4, 128, 1, 1};
+// int64_t ne2[4] = {4, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 25890.9375
+// main: optimized e = 10094.7031
+
+// int64_t ne1[4] = {8, 128, 1, 1};
+// int64_t ne2[4] = {8, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 39429.5078
+// main: optimized e = 9275.8936
+
+// int64_t ne1[4] = {16, 128, 1, 1};
+// int64_t ne2[4] = {16, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 68371.1328
+// main: optimized e = 7854.4502
+
+
+// int64_t ne1[4] = {32, 128, 1, 1};
+// int64_t ne2[4] = {32, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 126061.1953
+// main: optimized e = 5451.0166
+
+// int64_t ne1[4] = {4, 1024, 1, 1};
+// int64_t ne2[4] = {4, 2048, 1, 1};;
+// int64_t ne3[4] = {1024, 2048, 1, 1};
+// main: original  e = 1620817.8750
+// main: optimized e = 698387.6875
+
+// another run on M1
+// int64_t ne1[4] = {4, 1024, 1, 1};
+// int64_t ne2[4] = {4, 2048, 1, 1};;
+// int64_t ne3[4] = {1024, 2048, 1, 1};
+// main: original  e = 1629595.6250
+// main: optimized e = 698169.1250
+
+// int64_t ne1[4] = {32, 1024, 1, 1};
+// int64_t ne2[4] = {32, 2048, 1, 1};;
+// int64_t ne3[4] = {1024, 2048, 1, 1};
+// main: original  e = 8146770.5000
+// main: optimized e = 651119.1250

Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff