
llama : refactor llama_context, llama_kv_cache, llm_build_context (#12181)

* llama : refactor llama_context, llama_kv_cache, llm_build_context

ggml-ci

* graph : don't mutate the KV cache during defrag

ggml-ci

* context : reduce virtuals + remove test function

ggml-ci

* context : move interface implementation to source file + factory

ggml-ci

* graph : move KV cache build functions to llama_context impl

ggml-ci

* graph : remove model reference from build_pooling

ggml-ci

* graph : remove llama_model reference

ggml-ci

* kv_cache : provide rope factors

ggml-ci

* graph : rework inputs to use only unique_ptr, remove attn input abstraction

ggml-ci

* context : remove llama_context_i abstraction

ggml-ci

* context : clean-up

ggml-ci

* graph : clean-up

ggml-ci

* llama : remove redundant keywords (struct, enum)

ggml-ci

* model : adapt gemma3

ggml-ci

* graph : restore same attention ops as on master

ggml-ci

* llama : remove TODO + fix indent

ggml-ci
Georgi Gerganov 10 months ago
parent
commit
e0dbec0bc6
46 files changed, 12134 insertions and 10879 deletions
  1. common/common.cpp (+3 -3)
  2. common/speculative.cpp (+4 -4)
  3. examples/batched-bench/batched-bench.cpp (+2 -2)
  4. examples/batched.swift/Sources/main.swift (+1 -1)
  5. examples/cvector-generator/cvector-generator.cpp (+1 -1)
  6. examples/embedding/embedding.cpp (+1 -1)
  7. examples/gritlm/gritlm.cpp (+2 -2)
  8. examples/imatrix/imatrix.cpp (+1 -1)
  9. examples/infill/infill.cpp (+2 -2)
  10. examples/llama-bench/llama-bench.cpp (+2 -2)
  11. examples/llama.android/llama/src/main/cpp/llama-android.cpp (+4 -4)
  12. examples/llama.swiftui/llama.cpp.swift/LibLlama.swift (+4 -4)
  13. examples/llava/gemma3-cli.cpp (+1 -1)
  14. examples/lookahead/lookahead.cpp (+6 -6)
  15. examples/lookup/lookup.cpp (+1 -1)
  16. examples/main/main.cpp (+6 -6)
  17. examples/parallel/parallel.cpp (+5 -5)
  18. examples/passkey/passkey.cpp (+14 -14)
  19. examples/perplexity/perplexity.cpp (+6 -6)
  20. examples/quantize-stats/quantize-stats.cpp (+2 -2)
  21. examples/retrieval/retrieval.cpp (+1 -1)
  22. examples/run/run.cpp (+2 -2)
  23. examples/save-load-state/save-load-state.cpp (+2 -2)
  24. examples/server/server.cpp (+11 -11)
  25. examples/server/tests/utils.py (+1 -1)
  26. examples/simple-chat/simple-chat.cpp (+2 -2)
  27. examples/speculative-simple/speculative-simple.cpp (+1 -1)
  28. examples/speculative/speculative.cpp (+13 -13)
  29. include/llama.h (+82 -22)
  30. src/CMakeLists.txt (+5 -2)
  31. src/llama-adapter.cpp (+19 -20)
  32. src/llama-adapter.h (+11 -9)
  33. src/llama-batch.h (+2 -2)
  34. src/llama-context.cpp (+646 -546)
  35. src/llama-context.h (+211 -77)
  36. src/llama-graph.cpp (+1695 -0)
  37. src/llama-graph.h (+576 -0)
  38. src/llama-io.cpp (+15 -0)
  39. src/llama-io.h (+35 -0)
  40. src/llama-kv-cache.cpp (+989 -281)
  41. src/llama-kv-cache.h (+178 -110)
  42. src/llama-memory.cpp (+1 -0)
  43. src/llama-memory.h (+21 -0)
  44. src/llama-model.cpp (+7529 -219)
  45. src/llama-model.h (+18 -1)
  46. src/llama.cpp (+0 -9489)
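
Most of the churn in the examples above is mechanical: the public KV-cache API moves from the llama_kv_cache_* prefix to llama_kv_self_* (operating on the context's own KV cache, cf. the new llama_get_kv_self accessor), while the old names remain in include/llama.h as deprecated wrappers. A minimal before/after sketch, assuming an already-initialized llama_context * ctx:

    // before this commit
    llama_kv_cache_clear (ctx);
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1);

    // after this commit (the old names still compile, with deprecation warnings)
    llama_kv_self_clear (ctx);
    llama_kv_self_seq_rm(ctx, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1);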

+ 3 - 3
common/common.cpp

@@ -955,8 +955,8 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
-    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
-        LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
+    if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
+        LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
     }
 
@@ -1060,7 +1060,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         if (llama_model_has_decoder(model)) {
             llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
         }
-        llama_kv_cache_clear(lctx);
+        llama_kv_self_clear(lctx);
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
     }

+ 4 - 4
common/speculative.cpp

@@ -173,7 +173,7 @@ llama_tokens common_speculative_gen_draft(
     result.reserve(params.n_draft);
 
     if (reuse_n == 0) {
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         prompt.clear();
     } else {
@@ -192,14 +192,14 @@ llama_tokens common_speculative_gen_draft(
         }
 
         if (reuse_i > 0) {
-            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+            llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
             prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
         }
 
         if (reuse_n < (int) prompt.size()) {
-            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+            llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
 
             prompt.erase(prompt.begin() + reuse_n, prompt.end());
         }
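
The renamed functions keep the documented range semantics (p0 < 0 means [0, p1], p1 < 0 means [p0, inf)). A sketch of the reuse arithmetic above with hypothetical values: if the first reuse_i = 4 cached draft tokens no longer match but the following reuse_n = 12 do, the matching ones are kept and re-anchored at position 0:

    llama_kv_self_seq_rm (ctx, 0, 0, 4);      // drop cells with pos in [0, 4)
    llama_kv_self_seq_add(ctx, 0, 4, -1, -4); // shift pos of cells in [4, inf) by -4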

+ 2 - 2
examples/batched-bench/batched-bench.cpp

@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
 
                 const auto t_pp_start = ggml_time_us();
 
-                llama_kv_cache_clear(ctx);
+                llama_kv_self_clear(ctx);
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +141,7 @@ int main(int argc, char ** argv) {
 
                 if (is_pp_shared) {
                     for (int32_t i = 1; i < pl; ++i) {
-                        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                        llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
                     }
                 }
 

+ 1 - 1
examples/batched.swift/Sources/main.swift

@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
 }
 
 for i in 1 ..< n_parallel {
-    llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
+    llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
 }
 
 if n_parallel > 1 {

+ 1 - 1
examples/cvector-generator/cvector-generator.cpp

@@ -342,7 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 }
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;

+ 1 - 1
examples/embedding/embedding.cpp

@@ -38,7 +38,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

+ 2 - 2
examples/gritlm/gritlm.cpp

@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         }
 
         // clear previous kv_cache values (irrelevant for embeddings)
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
         llama_set_embeddings(ctx, true);
         llama_set_causal_attn(ctx, false);
 
@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
 
     llama_token eos_token = llama_vocab_eos(vocab);
 
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 

+ 1 - 1
examples/imatrix/imatrix.cpp

@@ -495,7 +495,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         llama_batch batch = llama_batch_init(n_batch, 0, 1);
 

+ 2 - 2
examples/infill/infill.cpp

@@ -332,8 +332,8 @@ int main(int argc, char ** argv) {
                 LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                     n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
                 n_past -= n_discard;
 

+ 2 - 2
examples/llama-bench/llama-bench.cpp

@@ -1578,7 +1578,7 @@ int main(int argc, char ** argv) {
 
         test t(inst, lmodel, ctx);
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // cool off before the test
         if (params.delay) {
@@ -1618,7 +1618,7 @@ int main(int argc, char ** argv) {
         }
 
         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_clear(ctx);
+            llama_kv_self_clear(ctx);
 
             uint64_t t_start = get_time_ns();
 

+ 4 - 4
examples/llama.android/llama/src/main/cpp/llama-android.cpp

@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         }
 
         batch->logits[batch->n_tokens - 1] = true;
-        llama_kv_cache_clear(context);
+        llama_kv_self_clear(context);
 
         const auto t_pp_start = ggml_time_us();
         if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
         LOGi("Benchmark text generation (tg)");
 
-        llama_kv_cache_clear(context);
+        llama_kv_self_clear(context);
         const auto t_tg_start = ggml_time_us();
         for (i = 0; i < tg; i++) {
 
@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
         const auto t_tg_end = ggml_time_us();
 
-        llama_kv_cache_clear(context);
+        llama_kv_self_clear(context);
 
         const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
         const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
-    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
+    llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
 }

+ 4 - 4
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

@@ -210,7 +210,7 @@ actor LlamaContext {
             }
             batch.logits[Int(batch.n_tokens) - 1] = 1 // true
 
-            llama_kv_cache_clear(context)
+            llama_kv_self_clear(context)
 
             let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
@@ -223,7 +223,7 @@ actor LlamaContext {
 
             // bench text generation
 
-            llama_kv_cache_clear(context)
+            llama_kv_self_clear(context)
 
             let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
@@ -242,7 +242,7 @@ actor LlamaContext {
 
             let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
 
-            llama_kv_cache_clear(context)
+            llama_kv_self_clear(context)
 
             let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
             let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -292,7 +292,7 @@ actor LlamaContext {
     func clear() {
         tokens_list.removeAll()
         temporary_invalid_cchars.removeAll()
-        llama_kv_cache_clear(context)
+        llama_kv_self_clear(context)
     }
 
     private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

+ 1 - 1
examples/llava/gemma3-cli.cpp

@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
             }
             if (line == "/clear") {
                 ctx.n_past = 0;
-                llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
+                llama_kv_self_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
                 LOG("Chat history cleared\n\n");
                 continue;
             }

+ 6 - 6
examples/lookahead/lookahead.cpp

@@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
     llama_decode(ctx, llama_batch_get_one(&inp.back(),           1));
 
     for (int s = 1; s < W + G + 1; ++s) {
-        llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+        llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
     }
 
     const auto t_enc_end = ggml_time_us();
@@ -438,17 +438,17 @@ int main(int argc, char ** argv) {
 
         // KV cache management
         // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
-        llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
+        llama_kv_self_seq_rm(ctx, -1, n_past, -1);
 
         if (seq_id_best != 0) {
             // if a verification token matched, we keep the best sequence and remove the rest
             // this leads to some KV cache fragmentation
-            llama_kv_cache_seq_keep(ctx, seq_id_best);
-            llama_kv_cache_seq_cp  (ctx, seq_id_best, 0, -1, -1);
-            llama_kv_cache_seq_rm  (ctx, seq_id_best,    -1, -1);
+            llama_kv_self_seq_keep(ctx, seq_id_best);
+            llama_kv_self_seq_cp  (ctx, seq_id_best, 0, -1, -1);
+            llama_kv_self_seq_rm  (ctx, seq_id_best,    -1, -1);
 
             for (int s = 1; s < W + G + 1; ++s) {
-                llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+                llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
             }
         }
     }

+ 1 - 1
examples/lookup/lookup.cpp

@@ -192,7 +192,7 @@ int main(int argc, char ** argv){
 
         // KV cache management
         // clean the cache of draft tokens that weren't accepted
-        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+        llama_kv_self_seq_rm(ctx, 0, n_past, -1);
 
         common_batch_clear(batch_tgt);
         common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

+ 6 - 6
examples/main/main.cpp

@@ -354,7 +354,7 @@ int main(int argc, char ** argv) {
         }
 
         // remove any "future" tokens that we might have inherited from the previous session
-        llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+        llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
     }
 
     LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -602,8 +602,8 @@ int main(int argc, char ** argv) {
                     LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                             n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_self_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
 
@@ -626,9 +626,9 @@ int main(int argc, char ** argv) {
                     LOG_DBG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
                     LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-                    llama_kv_cache_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
-                    llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
-                    llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
+                    llama_kv_self_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
+                    llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                    llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
 
                     n_past -= bd;
 

+ 5 - 5
examples/parallel/parallel.cpp

@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
 
         // assign the system KV cache to all parallel sequences
         for (int32_t i = 1; i <= n_clients; ++i) {
-            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+            llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
         }
 
         LOG_INF("\n");
@@ -234,9 +234,9 @@ int main(int argc, char ** argv) {
         if (batch.n_tokens == 0) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
-                llama_kv_cache_seq_rm(ctx, i, -1, -1);
+                llama_kv_self_seq_rm(ctx, i, -1, -1);
                 // but keep the system prompt
-                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
             }
 
             LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -372,8 +372,8 @@ int main(int argc, char ** argv) {
                     }
 
                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_kv_cache_seq_rm(ctx,    client.id + 1, -1, -1);
-                    llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
+                    llama_kv_self_seq_rm(ctx,    client.id + 1, -1, -1);
+                    llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
                     const auto t_main_end = ggml_time_us();
 

+ 14 - 14
examples/passkey/passkey.cpp

@@ -133,11 +133,11 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
-            llama_kv_cache_update  (ctx);
+            llama_kv_self_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_self_update  (ctx);
 
-            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+            n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
 
         common_batch_clear(batch);
@@ -167,12 +167,12 @@ int main(int argc, char ** argv) {
 
         LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
 
-        llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-      //llama_kv_cache_defrag (ctx);
-        llama_kv_cache_update (ctx);
+        llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+        llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+      //llama_kv_self_defrag (ctx);
+        llama_kv_self_update (ctx);
 
-        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+        n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
 
         common_batch_clear(batch);
 
@@ -198,12 +198,12 @@ int main(int argc, char ** argv) {
         if (n_discard > 0) {
             LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-          //llama_kv_cache_defrag (ctx);
-            llama_kv_cache_update (ctx);
+            llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+          //llama_kv_self_defrag (ctx);
+            llama_kv_self_update (ctx);
 
-            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+            n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
     }
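
As the header comments note, position edits such as llama_kv_self_seq_add and llama_kv_self_seq_div on a RoPEd cache are applied lazily on the next llama_decode(), or eagerly via llama_kv_self_update(), which is why passkey calls it immediately after shifting. The idiom, in sketch form:

    llama_kv_self_seq_rm (ctx, 0, n_keep, n_keep + n_discard);            // discard a window of cells
    llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); // close the gap
    llama_kv_self_update (ctx); // apply the pending K-shift now, not on the next decode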
 

+ 6 - 6
examples/perplexity/perplexity.cpp

@@ -361,7 +361,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
@@ -547,7 +547,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -924,7 +924,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             return;
         }
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1203,7 +1203,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             return;
         }
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1575,7 +1575,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             return;
         }
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1765,7 +1765,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         }
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         llama_batch batch = llama_batch_init(n_batch, 0, 1);
 

+ 2 - 2
examples/quantize-stats/quantize-stats.cpp

@@ -1,6 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-context.h"
+#include "llama-model.h"
 #include "common.h"
 
 #include <algorithm>
@@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    const auto & tensors = llama_internal_get_tensor_map(ctx);
+    const auto & tensors = llama_internal_get_tensor_map(model);
 
     // check layer tensors
     int included_layers = 0;

+ 1 - 1
examples/retrieval/retrieval.cpp

@@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

+ 2 - 2
examples/run/run.cpp

@@ -891,7 +891,7 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                            std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
-    const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
+    const bool is_first = llama_kv_self_used_cells(llama_data.context.get()) == 0;
 
     const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
     prompt_tokens.resize(n_prompt_tokens);
@@ -907,7 +907,7 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
 // Check if we have enough space in the context to evaluate this batch
 static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
     const int n_ctx      = llama_n_ctx(ctx.get());
-    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
+    const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
     if (n_ctx_used + batch.n_tokens > n_ctx) {
         printf(LOG_COL_DEFAULT "\n");
         printe("context size exceeded\n");

+ 2 - 2
examples/save-load-state/save-load-state.cpp

@@ -15,7 +15,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    print_build_info();
+    common_init();
 
     if (params.n_predict < 0) {
         params.n_predict = 16;
@@ -196,7 +196,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
 
         // erase whole kv
-        llama_kv_cache_clear(ctx3);
+        llama_kv_self_clear(ctx3);
         fprintf(stderr, "%s : kv cache cleared\n", __func__);
 
         // restore kv into seq 1

+ 11 - 11
examples/server/server.cpp

@@ -2113,7 +2113,7 @@ struct server_context {
         SRV_DBG("%s", "clearing KV cache\n");
 
         // clear the entire KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
         clean_kv_cache = false;
     }
 
@@ -2655,8 +2655,8 @@ struct server_context {
                     res->n_tasks_deferred    = queue_tasks.queue_tasks_deferred.size();
                     res->t_start             = metrics.t_start;
 
-                    res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
-                    res->kv_cache_used_cells   = llama_get_kv_cache_used_cells(ctx);
+                    res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
+                    res->kv_cache_used_cells   = llama_kv_self_used_cells(ctx);
 
                     res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
                     res->t_prompt_processing_total       = metrics.t_prompt_processing_total;
@@ -2772,7 +2772,7 @@ struct server_context {
 
                     // Erase token cache
                     const size_t n_erased = slot->cache_tokens.size();
-                    llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
+                    llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
                     slot->cache_tokens.clear();
 
                     auto res = std::make_unique<server_task_result_slot_erase>();
@@ -2840,8 +2840,8 @@ struct server_context {
 
                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
+                llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
+                llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
 
                 if (slot.params.cache_prompt) {
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -3032,8 +3032,8 @@ struct server_context {
 
                                             const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
 
-                                            llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
-                                            llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
+                                            llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
+                                            llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
 
                                             for (size_t i = 0; i < n_match; i++) {
                                                 slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@@ -3071,9 +3071,9 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
+                    if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
+                        llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
 
                         // there is no common part left
                         slot.n_past = 0;
@@ -3313,7 +3313,7 @@ struct server_context {
                 slot.cache_tokens.push_back(id);
                 slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
 
-                llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+                llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
 
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;

+ 1 - 1
examples/server/tests/utils.py

@@ -302,7 +302,7 @@ class ServerPreset:
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K.gguf"
         server.model_alias = "tinyllama-2"
-        server.n_ctx = 256
+        server.n_ctx = 512
         server.n_batch = 32
         server.n_slots = 2
         server.n_predict = 64

+ 2 - 2
examples/simple-chat/simple-chat.cpp

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
-        const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
+        const bool is_first = llama_kv_self_used_cells(ctx) == 0;
 
         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
         while (true) {
             // check if we have enough space in the context to evaluate this batch
             int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+            int n_ctx_used = llama_kv_self_used_cells(ctx);
             if (n_ctx_used + batch.n_tokens > n_ctx) {
                 printf("\033[0m\n");
                 fprintf(stderr, "context size exceeded\n");

+ 1 - 1
examples/speculative-simple/speculative-simple.cpp

@@ -217,7 +217,7 @@ int main(int argc, char ** argv) {
         {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
 
-            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
         }
 
         if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {

+ 13 - 13
examples/speculative/speculative.cpp

@@ -420,14 +420,14 @@ int main(int argc, char ** argv) {
             {
                 LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
-                llama_kv_cache_seq_keep(ctx_dft, s_keep);
-                llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
-                llama_kv_cache_seq_keep(ctx_dft, 0);
-
-                llama_kv_cache_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
-                llama_kv_cache_seq_keep(ctx_tgt, s_keep);
-                llama_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
-                llama_kv_cache_seq_keep(ctx_tgt, 0);
+                llama_kv_self_seq_keep(ctx_dft, s_keep);
+                llama_kv_self_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
+                llama_kv_self_seq_keep(ctx_dft, 0);
+
+                llama_kv_self_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
+                llama_kv_self_seq_keep(ctx_tgt, s_keep);
+                llama_kv_self_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
+                llama_kv_self_seq_keep(ctx_tgt, 0);
             }
 
             for (int s = 0; s < n_seq_dft; ++s) {
@@ -444,7 +444,7 @@ int main(int argc, char ** argv) {
             common_batch_clear(batch_dft);
             common_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
 
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
+            llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
             // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
             llama_decode(ctx_dft, batch_dft);
 
@@ -503,8 +503,8 @@ int main(int argc, char ** argv) {
                     if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
                         LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
 
-                        llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
-                        llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
+                        llama_kv_self_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
+                        llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
 
                         // all previous tokens from this branch are now also part of the new branch
                         for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@@ -585,9 +585,9 @@ int main(int argc, char ** argv) {
 
         // evaluate the target model on the drafted tokens
         {
-            llama_kv_cache_seq_keep(ctx_tgt, 0);
+            llama_kv_self_seq_keep(ctx_tgt, 0);
             for (int s = 1; s < n_seq_dft; ++s) {
-                llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
+                llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
             }
 
             // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());

+ 82 - 22
include/llama.h

@@ -60,6 +60,7 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
+    struct llama_kv_cache;
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -469,7 +470,8 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab    (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
 
     LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
-    LLAMA_API enum llama_pooling_type    llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API    struct llama_kv_cache * llama_get_kv_self (      struct llama_context * ctx);
+    LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
@@ -586,7 +588,7 @@ extern "C" {
     // KV cache
     //
 
-    // TODO: remove llama_kv_cache_view_* API
+    // TODO: start using struct llama_kv_cache
 
     // Information associated with an individual cell in the KV cache view.
     struct llama_kv_cache_view_cell {
@@ -641,13 +643,19 @@ extern "C" {
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
+    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
+            "use llama_kv_self_n_tokens instead");
 
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
+    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
+            "use llama_kv_self_used_cells instead");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
-    LLAMA_API void llama_kv_cache_clear(
+    LLAMA_API void llama_kv_self_clear(
             struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@@ -655,7 +663,7 @@ extern "C" {
     // seq_id < 0 : match any sequence
     // p0 < 0     : [0,  p1]
     // p1 < 0     : [p0, inf)
-    LLAMA_API bool llama_kv_cache_seq_rm(
+    LLAMA_API bool llama_kv_self_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -665,7 +673,7 @@ extern "C" {
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_cp(
+    LLAMA_API void llama_kv_self_seq_cp(
             struct llama_context * ctx,
                     llama_seq_id   seq_id_src,
                     llama_seq_id   seq_id_dst,
@@ -673,17 +681,17 @@ extern "C" {
                        llama_pos   p1);
 
     // Removes all tokens that do not belong to the specified sequence
-    LLAMA_API void llama_kv_cache_seq_keep(
+    LLAMA_API void llama_kv_self_seq_keep(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
+    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_add(
+    LLAMA_API void llama_kv_self_seq_add(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -693,10 +701,10 @@ extern "C" {
     // Integer division of the positions by factor of `d > 1`
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
+    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_div(
+    LLAMA_API void llama_kv_self_seq_div(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -704,24 +712,76 @@ extern "C" {
                              int   d);
 
     // Returns the largest position present in the KV cache for the specified sequence
-    LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+    LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
-                    llama_seq_id   seq_id);
-
-    // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
-    //       how to avoid this?
+                     llama_seq_id   seq_id);
 
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
-    LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
+    //   - explicitly with llama_kv_self_update()
+    LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
+
+    // Check if the context supports KV cache shifting
+    LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
+    LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_clear(
+            struct llama_context * ctx),
+            "use llama_kv_self_clear instead");
+
+    DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+                       llama_pos   p0,
+                       llama_pos   p1),
+            "use llama_kv_self_seq_rm instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id_src,
+                    llama_seq_id   seq_id_dst,
+                       llama_pos   p0,
+                       llama_pos   p1),
+            "use llama_kv_self_seq_cp instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id),
+            "use llama_kv_self_seq_keep instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+                       llama_pos   p0,
+                       llama_pos   p1,
+                       llama_pos   delta),
+            "use llama_kv_self_seq_add instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+                       llama_pos   p0,
+                       llama_pos   p1,
+                             int   d),
+            "use llama_kv_self_seq_div instead");
+
+    DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id),
+            "use llama_kv_self_seq_pos_max instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
+            "use llama_kv_self_defrag instead");
+
+    DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
+            "use llama_kv_self_can_shift instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
+            "use llama_kv_self_update instead");
 
-    // Check if the context supports KV cache shifting
-    LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
 
     //
     // State / sessions
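
For backward compatibility the old symbols stay in the header behind DEPRECATED(...), so existing callers keep compiling with a warning. The llama-context.cpp diff is collapsed below, but the deprecated entry points can be expected to be thin forwarders; a sketch of the presumed shape (not verbatim from the commit):

    // deprecated shim, kept for source compatibility
    void llama_kv_cache_clear(llama_context * ctx) {
        llama_kv_self_clear(ctx);
    }

    bool llama_kv_cache_seq_rm(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
        return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
    }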

+ 5 - 2
src/CMakeLists.txt

@@ -15,18 +15,21 @@ add_library(llama
             llama-chat.cpp
             llama-context.cpp
             llama-grammar.cpp
+            llama-graph.cpp
             llama-hparams.cpp
             llama-impl.cpp
+            llama-io.cpp
             llama-kv-cache.cpp
+            llama-memory.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model.cpp
             llama-quant.cpp
             llama-sampling.cpp
             llama-vocab.cpp
-            unicode.h
-            unicode.cpp
             unicode-data.cpp
+            unicode.cpp
+            unicode.h
             )
 
 target_include_directories(llama PUBLIC . ../include ../common)

+ 19 - 20
src/llama-adapter.cpp

@@ -4,14 +4,13 @@
 #include "llama-mmap.h"
 #include "llama-model.h"
 
-#include <algorithm>
 #include <map>
 #include <cassert>
 #include <stdexcept>
 
 // vec
 
-struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
+ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
         return nullptr;
     }
@@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     return tensors[il];
 }
 
-struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int  il) const {
+ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int  il) const {
     ggml_tensor * layer_dir = tensor_for(il);
     if (layer_dir != nullptr) {
         cur = ggml_add(ctx, cur, layer_dir);
@@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size   =*/ hparams.n_layer*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     return true;
 }
 
-int32_t llama_adapter_cvec::apply(
+bool llama_adapter_cvec::apply(
         const llama_model & model,
         const float * data,
         size_t len,
@@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply(
         // disable the current control vector (but leave allocated for later)
         layer_start = -1;
         layer_end   = -1;
-        return 0;
+        return true;
     }
 
     if (n_embd != (int) hparams.n_embd) {
         LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return 1;
+        return false;
     }
 
     if (tensors.empty()) {
         if (!init(model)) {
-            return 1;
+            return false;
         }
     }
 
@@ -130,12 +129,12 @@ int32_t llama_adapter_cvec::apply(
         }
     }
 
-    return 0;
+    return true;
 }
 
 // lora
 
-llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
+llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     const std::string name(w->name);
 
     const auto pos = ab_map.find(name);
@@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
     return nullptr;
 }
 
-static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
+static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
     ggml_context * ctx_init;
-    struct gguf_init_params meta_gguf_params = {
+    gguf_init_params meta_gguf_params = {
         /* .no_alloc = */ true,
         /* .ctx      = */ &ctx_init,
     };
@@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             // add a new context
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -264,7 +263,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
         }
 
-        struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
+        ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
         // validate tensor shape
         if (is_token_embd) {
             // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@@ -281,8 +280,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         }
 
         // save tensor to adapter
-        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
-        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
+        ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
+        ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
         ggml_set_name(tensor_a, w.a->name);
         ggml_set_name(tensor_b, w.b->name);
         adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@@ -308,7 +307,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     {
         llama_file gguf_file(path_lora, "rb");
         std::vector<uint8_t> read_buf;
-        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
+        auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
             size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
             size_t size = ggml_nbytes(orig);
             read_buf.resize(size);
@@ -327,8 +326,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
 
-struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
-    struct llama_adapter_lora * adapter = new llama_adapter_lora();
+llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
+    llama_adapter_lora * adapter = new llama_adapter_lora();
 
     try {
         llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@@ -342,6 +341,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
     return nullptr;
 }
 
-void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
+void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
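
Note the signature change above: llama_adapter_cvec::apply() now returns bool (true on success) instead of an int32_t status code where 0 meant success, so call sites invert their checks. A sketch of the updated pattern (hypothetical caller):

    if (!cvec.apply(model, data, len, n_embd, il_start, il_end)) {
        LLAMA_LOG_ERROR("%s: failed to apply control vector\n", __func__);
        return false;
    }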

+ 11 - 9
src/llama-adapter.h

@@ -15,11 +15,11 @@
 //
 
 struct llama_adapter_cvec {
-    struct ggml_tensor * tensor_for(int il) const;
+    ggml_tensor * tensor_for(int il) const;
 
-    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int  il) const;
+    ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int  il) const;
 
-    int32_t apply(
+    bool apply(
             const llama_model & model,
             const float * data,
             size_t len,
@@ -36,7 +36,7 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    std::vector<struct ggml_tensor *> tensors; // per layer
+    std::vector<ggml_tensor *> tensors; // per layer
 };
 
 //
@@ -44,8 +44,8 @@ private:
 //
 
 struct llama_adapter_lora_weight {
-    struct ggml_tensor * a = nullptr;
-    struct ggml_tensor * b = nullptr;
+    ggml_tensor * a = nullptr;
+    ggml_tensor * b = nullptr;
 
     // get actual scale based on rank and alpha
     float get_scale(float alpha, float adapter_scale) const {
@@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
     }
 
     llama_adapter_lora_weight() = default;
-    llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
+    llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
 };
 
 struct llama_adapter_lora {
     // map tensor name to lora_a_b
-    std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
+    std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
 
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
@@ -70,5 +70,7 @@ struct llama_adapter_lora {
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 
-    llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
+    llama_adapter_lora_weight * get_weight(ggml_tensor * w);
 };
+
+using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
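
The new llama_adapter_loras alias names the adapter-to-scale map that llama_context now owns (the loras member in llama-context.h below). A hedged usage sketch through the public API, assuming the llama_set_adapter_lora entry point that pairs with llama_adapter_lora_init above:

    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "lora.gguf"); // path is illustrative
    if (adapter != nullptr) {
        // stored in the context's llama_adapter_loras map with the given scale
        llama_set_adapter_lora(ctx, adapter, 0.5f);
    }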

+ 2 - 2
src/llama-batch.h

@@ -42,9 +42,9 @@ struct llama_sbatch {
     bool logits_all; // TODO: remove once lctx.logits_all is removed too
 
     // sorted indices into the batch
-    std::vector<size_t> ids;
+    std::vector<int64_t> ids;
     // batch indices of the output
-    std::vector<size_t> out_ids;
+    std::vector<int64_t> out_ids;
     std::vector<llama_sbatch_seq> seq;
 
     const llama_batch * batch = nullptr;

These changes were omitted because the diff is too large
+ 646 - 546
src/llama-context.cpp


+ 211 - 77
src/llama-context.h

@@ -3,66 +3,210 @@
 #include "llama.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
-#include "llama-model.h"
-#include "llama-kv-cache.h"
+#include "llama-graph.h"
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
 
 #include <map>
-#include <unordered_map>
 #include <vector>
-#include <set>
+
+struct llama_model;
+struct llama_kv_cache;
+
+class llama_io_read_i;
+class llama_io_write_i;
 
 struct llama_context {
-    llama_context(const llama_model & model)
-        : model(model)
-        , t_start_us(model.t_start_us)
-        , t_load_us(model.t_load_us) {}
+    // init scheduler and compute buffers, reserve worst-case graphs
+    llama_context(
+            const llama_model & model,
+                  llama_context_params params);
 
-    const struct llama_model & model;
+    ~llama_context();
 
-    struct llama_cparams      cparams;
-    struct llama_sbatch       sbatch;  // TODO: revisit if needed
-    struct llama_kv_cache     kv_self;
-    struct llama_adapter_cvec cvec;
+    void synchronize();
 
-    std::unordered_map<struct llama_adapter_lora *, float> lora;
+    const llama_model & get_model() const;
 
-    std::vector<ggml_backend_ptr> backends;
-    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
+    uint32_t n_ctx()         const;
+    uint32_t n_ctx_per_seq() const;
+    uint32_t n_batch()       const;
+    uint32_t n_ubatch()      const;
+    uint32_t n_seq_max()     const;
 
-    ggml_backend_t backend_cpu = nullptr;
+    uint32_t n_threads()       const;
+    uint32_t n_threads_batch() const;
 
-    ggml_threadpool_t threadpool       = nullptr;
-    ggml_threadpool_t threadpool_batch = nullptr;
+          llama_kv_cache * get_kv_self();
+    const llama_kv_cache * get_kv_self() const;
 
-    bool has_evaluated_once = false;
+    void kv_self_update();
 
-    mutable int64_t t_start_us;
-    mutable int64_t t_load_us;
-    mutable int64_t t_p_eval_us = 0;
-    mutable int64_t t_eval_us   = 0;
+    enum llama_pooling_type pooling_type() const;
 
-    mutable int64_t t_compute_start_us = 0;
-    mutable int64_t n_queued_tokens = 0;
+    float * get_logits();
+    float * get_logits_ith(int32_t i);
 
-    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    mutable int32_t n_eval   = 0; // number of eval calls
+    float * get_embeddings();
+    float * get_embeddings_ith(int32_t i);
+    float * get_embeddings_seq(llama_seq_id seq_id);
 
-    // host buffer for the model output (logits and embeddings)
-    ggml_backend_buffer_ptr buf_output;
+    void attach_threadpool(
+            ggml_threadpool_t threadpool,
+            ggml_threadpool_t threadpool_batch);
 
-    // decode output (2-dimensional array: [n_outputs][n_vocab])
-    size_t  logits_size = 0; // capacity (of floats) for logits
-    float * logits      = nullptr;
+    void detach_threadpool();
 
-    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
-    size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch
+    void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
+
+    void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
+
+    void set_embeddings (bool value);
+    void set_causal_attn(bool value);
+
+    void set_adapter_lora(
+            llama_adapter_lora * adapter,
+            float scale);
+
+    bool rm_adapter_lora(
+            llama_adapter_lora * adapter);
+
+    void clear_adapter_lora();
+
+    bool apply_adapter_cvec(
+            const float * data,
+                 size_t   len,
+                int32_t   n_embd,
+                int32_t   il_start,
+                int32_t   il_end);
+
+    int encode(llama_batch & inp_batch);
+    int decode(llama_batch & inp_batch);
+
+    //
+    // state save/load
+    //
+
+    size_t state_get_size();
+    size_t state_get_data(      uint8_t * dst, size_t size);
+    size_t state_set_data(const uint8_t * src, size_t size);
+
+    size_t state_seq_get_size(llama_seq_id seq_id);
+    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size);
+    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
+
+    bool state_load_file(
+            const char * filepath,
+           llama_token * tokens_out,
+                size_t   n_token_capacity,
+                size_t * n_token_count_out);
+
+    bool state_save_file(
+            const char * filepath,
+     const llama_token * tokens,
+                size_t   n_token_count);
+
+    size_t state_seq_load_file(
+          llama_seq_id   seq_id,
+            const char * filepath,
+           llama_token * tokens_out,
+                size_t   n_token_capacity,
+                size_t * n_token_count_out);
+
+    size_t state_seq_save_file(
+          llama_seq_id   seq_id,
+            const char * filepath,
+     const llama_token * tokens,
+                size_t   n_token_count);
+
+    //
+    // perf
+    //
+
+    llama_perf_context_data perf_get_data() const;
+    void perf_reset();
+
+private:
+    //
+    // output
+    //
+
+    // Make sure enough space is available for outputs.
+    // Returns max number of outputs for which space was reserved.
+    int32_t output_reserve(int32_t n_outputs);
+
+    // make the outputs have the same order they had in the user-provided batch
+    // TODO: maybe remove this
+    void output_reorder();
 
+    //
+    // graph
+    //
+
+    int32_t graph_max_nodes() const;
+
+    // zero-out inputs and create the ctx_compute for the compute graph
+    ggml_cgraph * graph_init();
+
+    llm_graph_result_ptr graph_build(
+            ggml_context * ctx,
+             ggml_cgraph * gf,
+      const llama_ubatch & ubatch,
+          llm_graph_type   gtype);
+
+    // returns the result of ggml_backend_sched_graph_compute_async execution
+    ggml_status graph_compute(
+            ggml_cgraph * gf,
+                   bool   batched);
+
+    llm_graph_cb graph_get_cb() const;
+
+    // used by kv_self_update()
+    ggml_tensor * build_rope_shift(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * shift,
+        ggml_tensor * factors,
+        ggml_backend_buffer * bbuf) const;
+
+    llm_graph_result_ptr build_kv_self_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) const;
+
+    llm_graph_result_ptr build_kv_self_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) const;
+
+    // TODO: read/write lora adapters and cvec
+    size_t state_write_data(llama_io_write_i & io);
+    size_t state_read_data (llama_io_read_i  & io);
+
+    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
+    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id);
+
+    //
+    // members
+    //
+
+    const llama_model & model;
+
+    llama_cparams       cparams;
+    llama_adapter_cvec  cvec;
+    llama_adapter_loras loras;
+    llama_sbatch        sbatch;
+
+    llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
+
+    std::unique_ptr<llama_kv_cache_unified> kv_self;
+
+    // TODO: remove
     bool logits_all = false;
 
+    // decode output (2-dimensional array: [n_outputs][n_vocab])
+    size_t  logits_size = 0; // capacity (of floats) for logits
+    float * logits      = nullptr;
+
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
     size_t  embd_size = 0; // capacity (of floats) for embeddings
@@ -72,57 +216,47 @@ struct llama_context {
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
-    // whether we are computing encoder output or decoder output
-    bool is_encoding = false;
+    int32_t n_outputs     = 0; // number of actually-used outputs in the current ubatch or last logical batch
+    int32_t n_outputs_max = 0; // capacity (of token positions) for the output buffers
 
-    // TODO: find a better way to accommodate multi-dimensional position encoding methods
-    // number of position ids per token: 1 in most cases,
-    // but 3 per token when using m-rope, representing a 3-dimensional coordinate
-    int n_pos_per_token = 1;
-
-    // output of the encoder part of the encoder-decoder models
-    std::vector<float> embd_enc;
-    std::vector<std::set<llama_seq_id>> seq_ids_enc;
+    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
-    // memory buffers used to evaluate the model
-    std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_ptr sched;
 
+    ggml_backend_t backend_cpu = nullptr;
+    std::vector<ggml_backend_ptr> backends;
+
+    ggml_context_ptr ctx_compute;
+
+    ggml_threadpool_t threadpool       = nullptr;
+    ggml_threadpool_t threadpool_batch = nullptr;
+
     ggml_abort_callback abort_callback      = nullptr;
     void *              abort_callback_data = nullptr;
 
-    // input tensors
-    struct ggml_tensor * inp_tokens;        // I32 [n_batch]
-    struct ggml_tensor * inp_embd;          // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;           // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;       // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;       // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_KQ_mask_swa;   // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;       // I32 [kv_size]
-    struct ggml_tensor * inp_mean;          // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;           // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;        // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;        // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;         // I32 [n_kv, n_batch]
-    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
-    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
-    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-};
+    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
 
-// TODO: make these methods of llama_context
-void llama_set_k_shift(struct llama_context & lctx);
+    // buffer types used for the compute buffer of each backend
+    std::vector<ggml_backend_t>             backend_ptrs;
+    std::vector<ggml_backend_buffer_type_t> backend_buft;
 
-void llama_set_s_copy(struct llama_context & lctx);
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
 
-void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
+    // host buffer for the model output (logits and embeddings)
+    ggml_backend_buffer_ptr buf_output;
 
-// Make sure enough space is available for outputs.
-// Returns max number of outputs for which space was reserved.
-size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
+    bool has_evaluated_once = false;
 
-// make the outputs have the same order they had in the user-provided batch
-void llama_output_reorder(struct llama_context & ctx);
+    // perf
+    mutable int64_t t_start_us  = 0;
+    mutable int64_t t_load_us   = 0;
+    mutable int64_t t_p_eval_us = 0;
+    mutable int64_t t_eval_us   = 0;
 
-// For internal test use
-// TODO: remove
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
+    mutable int64_t t_compute_start_us = 0;
+    mutable int64_t n_queued_tokens    = 0;
+
+    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+    mutable int32_t n_eval   = 0; // number of eval calls
+};
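The state_* members above back the public llama_state_* API in llama.h. A minimal, hedged usage sketch of saving and restoring a context through those wrappers (the model path, prompt tokens, and buffer size are placeholder values):

    // sketch: save and restore context state via the public llama.h wrappers;
    // "model.gguf" and the token values are placeholders
    #include "llama.h"

    #include <cstdio>
    #include <vector>

    int main() {
        llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
        if (model == nullptr) {
            return 1;
        }
        llama_context * ctx = llama_init_from_model(model, llama_context_default_params());

        // ... decode some prompt tokens here ...
        std::vector<llama_token> tokens = { 1, 2, 3 }; // placeholder prompt

        // persist the context state (KV cache, outputs) together with the prompt
        llama_state_save_file(ctx, "state.bin", tokens.data(), tokens.size());

        // later: restore the state and recover the prompt tokens
        std::vector<llama_token> tokens_out(1024);
        size_t n_token_count = 0;
        if (!llama_state_load_file(ctx, "state.bin", tokens_out.data(), tokens_out.size(), &n_token_count)) {
            fprintf(stderr, "failed to load state\n");
        }

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }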

+ 1695 - 0
src/llama-graph.cpp

@@ -0,0 +1,1695 @@
+#include "llama-graph.h"
+
+#include "llama-impl.h"
+#include "llama-batch.h"
+#include "llama-cparams.h"
+#include "llama-kv-cache.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstring>
+
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
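The helper above is T5-style relative position bucketing: offsets below max_exact get their own bucket, larger offsets share logarithmically spaced buckets up to max_distance, and in the bidirectional case the sign of the offset selects the upper or lower half of the range. A standalone sketch of the same mapping (the bucket count of 32 is an arbitrary example value):

    // sketch: T5 relative position bucketing, mirroring the function above
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    static int32_t relative_bucket(int32_t x, int32_t y, int32_t n_buckets, bool bidirectional) {
        const int32_t max_distance = 128;

        if (bidirectional) {
            n_buckets /= 2;
        }
        const int32_t max_exact = n_buckets/2;

        int32_t rel    = x - y;
        int32_t bucket = 0;

        if (bidirectional) {
            bucket += (rel > 0)*n_buckets;
            rel = std::abs(rel);
        } else {
            rel = -std::min(rel, 0);
        }

        if (rel < max_exact) {
            return bucket + rel; // exact buckets for small offsets
        }

        // logarithmically spaced buckets for large offsets
        const int32_t large = max_exact + (int32_t) floorf(
                logf((float) rel/max_exact)*(n_buckets - max_exact)/logf((float) max_distance/max_exact));

        return bucket + std::min(large, n_buckets - 1);
    }

    int main() {
        for (int d : {0, 1, 7, 8, 20, 100, 500}) {
            printf("offset %4d -> bucket %d\n", d, relative_bucket(d, 0, 32, true));
        }
    }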
+
+void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->token) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
+    }
+
+    if (ubatch->embd) {
+        const int64_t n_embd   = embd->ne[0];
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
+    }
+}
+
+void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->pos && pos) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
+    }
+}
+
+void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
+    if (pos_bucket) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
+        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+        int32_t * data = (int32_t *) pos_bucket->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
+                }
+            }
+        }
+    }
+}
+
+void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
+    if (pos_bucket) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
+        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+        int32_t * data = (int32_t *) pos_bucket->data;
+
+        const int64_t n_kv = kv_self->n;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                }
+            }
+        }
+    }
+}
+
+void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
+
+        if (!out_ids) {
+            LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
+        } else {
+            const int64_t n_tokens = ubatch->n_tokens;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
+            int32_t * data = (int32_t *) out_ids->data;
+
+            if (n_outputs == n_tokens) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[i] = i;
+                }
+            } else if (ubatch->output) {
+                int32_t n_outputs_found = 0;
+                for (int i = 0; i < n_tokens; ++i) {
+                    if (ubatch->output[i]) {
+                        data[n_outputs_found++] = i;
+                    }
+                }
+                // the graph needs to have been passed the correct number of outputs
+                GGML_ASSERT(n_outputs_found == n_outputs);
+            } else if (n_outputs == 1) {
+                // only keep last output
+                data[0] = n_tokens - 1;
+            } else {
+                GGML_ASSERT(n_outputs == 0);
+            }
+        }
+    }
+}
+
+void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
+
+        GGML_ASSERT(mean);
+        GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
+
+        float * data = (float *) mean->data;
+        memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
+
+        std::vector<uint64_t> sum(n_tokens, 0);
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+
+            sum[seq_id] += ubatch->n_seq_tokens;
+        }
+
+        std::vector<float> div(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const uint64_t s = sum[i];
+            if (s > 0) {
+                div[i] = 1.0f/float(s);
+            }
+        }
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+            }
+        }
+    }
+}
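The mean tensor filled above is an [n_tokens, n_tokens] matrix whose row for sequence s holds 1/len(s) in the columns of that sequence's tokens, so a single matmul with the token embeddings produces the per-sequence means. A toy sketch with made-up sequence assignments and 1-dimensional "embeddings":

    // sketch: mean pooling as a matmul with a per-sequence averaging matrix
    #include <cstdio>

    int main() {
        const int n_tokens = 4;
        const int seq_of_token[n_tokens] = { 0, 0, 0, 1 }; // seq 0 = tokens {0,1,2}, seq 1 = {3}
        const int seq_len[2]             = { 3, 1 };

        float mean[2][n_tokens] = {};
        for (int t = 0; t < n_tokens; ++t) {
            const int s = seq_of_token[t];
            mean[s][t] = 1.0f/seq_len[s]; // row s averages the tokens of sequence s
        }

        const float embd[n_tokens] = { 1.0f, 2.0f, 3.0f, 10.0f }; // 1-dim "embeddings"
        for (int s = 0; s < 2; ++s) {
            float out = 0.0f;
            for (int t = 0; t < n_tokens; ++t) {
                out += mean[s][t]*embd[t];
            }
            printf("seq %d mean = %.1f\n", s, out); // 2.0 and 10.0
        }
    }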
+
+void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
+
+        GGML_ASSERT(cls);
+        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
+
+        uint32_t * data = (uint32_t *) cls->data;
+        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+
+                if (pos == 0) {
+                    data[seq_id] = s*n_seq_tokens + i;
+                }
+            }
+        }
+    }
+
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
+
+        GGML_ASSERT(cls);
+        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
+
+        uint32_t * data = (uint32_t *) cls->data;
+        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
+                }
+            }
+        }
+
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+    }
+}
+
+void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    const int64_t n_kv = kv_self->n;
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self->head;
+
+            //////////////////////////////////////////////
+            // TODO: this should not mutate the KV cache !
+            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[cell_id];
+
+            // prevent out-of-bound sources
+            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
+                kv_cell.src = cell_id;
+            }
+
+            data[i] = kv_cell.src;
+
+            // TODO: do not mutate the KV cache
+            // ensure copy only happens once
+            if (kv_cell.src != (int32_t) cell_id) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
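For recurrent models, s_copy[i] tells the graph which cell each state slot between head and head + n should be sourced from; re-pointing src at the slot itself afterwards guarantees the copy is scheduled only once. A sketch of that copy-once convention (assuming head == 0 for brevity; the real code offsets by kv_self->head):

    // sketch: the copy-once convention used by s_copy
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        struct cell { int32_t src; };
        std::vector<cell> cells = { {2}, {-1}, {2} }; // slots 0 and 2 fork from cell 2

        std::vector<int32_t> s_copy(cells.size());
        for (size_t i = 0; i < cells.size(); ++i) {
            if (cells[i].src < 0 || (size_t) cells[i].src >= cells.size()) {
                cells[i].src = (int32_t) i; // out-of-bound sources read from themselves
            }
            s_copy[i] = cells[i].src;
            cells[i].src = (int32_t) i; // the next graph build will not repeat the copy
        }

        for (size_t i = 0; i < s_copy.size(); ++i) {
            printf("slot %zu <- cell %d\n", i, s_copy[i]);
        }
    }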
+
+void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    const int64_t n_kv = kv_self->n;
+
+    if (s_mask) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
+        float * data = (float *) s_mask->data;
+
+        // clear unused states
+        for (int i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self->head;
+
+            //////////////////////////////////////////////
+            // TODO: this should not mutate the KV cache !
+            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[cell_id];
+
+            data[i] = (float) (kv_cell.src >= 0);
+
+            // only clear once
+            if (kv_cell.src < 0) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (cross_embd && !cross->v_embd.empty()) {
+        assert(cross_embd->type == GGML_TYPE_F32);
+
+        ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
+    }
+}
+
+void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
+    if (kq_mask) {
+        if (cparams.causal_attn) {
+            const int64_t n_kv         = ubatch->n_tokens;
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+            float * data = (float *) kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
+                            }
+                        }
+                    }
+                }
+            }
+        } else {
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+            const int64_t n_stride     = ubatch->n_tokens;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+            float * data = (float *) kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
+                            }
+                        }
+
+                        for (int i = n_tokens; i < n_stride; ++i) {
+                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask || self_kq_mask_swa) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv         = kv_self->n;
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+
+            float * data     = nullptr;
+            float * data_swa = nullptr;
+
+            if (self_kq_mask) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+                data = (float *) self_kq_mask->data;
+            }
+
+            if (self_kq_mask_swa) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+                data_swa = (float *) self_kq_mask_swa->data;
+            }
+
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the ubatch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int s = 0; s < n_seqs; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+
+                        for (int i = 0; i < n_kv; ++i) {
+                            float f;
+                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
+                                f = -INFINITY;
+                            } else {
+                                if (hparams.use_alibi) {
+                                    f = -std::abs(kv_self->cells[i].pos - pos);
+                                } else {
+                                    f = 0.0f;
+                                }
+                            }
+
+                            if (data) {
+                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                            }
+
+                            // may need to cut off old tokens for sliding window
+                            if (data_swa) {
+                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                    f = -INFINITY;
+                                }
+                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                            }
+                        }
+                    }
+                }
+
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+            }
+        } else {
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_stride     = n_tokens;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+            float * data = (float *) self_kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
+                            }
+                        }
+
+                        for (int i = n_tokens; i < n_stride; ++i) {
+                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
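In the causal branch a token at position pos may only attend to cells of its own sequence with an equal or smaller position; the SWA mask additionally drops cells outside the sliding window, and with ALiBi the permitted entries hold -|Δpos| instead of 0. A toy sketch of the two conditions (single sequence, no ALiBi, n_swa = 2 chosen arbitrarily):

    // sketch: causal and sliding-window mask entries for one query token
    #include <cmath>
    #include <cstdio>

    int main() {
        const int n_kv  = 5;
        const int n_swa = 2;
        const int kv_pos[n_kv] = { 0, 1, 2, 3, 4 };

        const int pos = 4; // position of the current token
        for (int i = 0; i < n_kv; ++i) {
            float f     = kv_pos[i] <= pos ? 0.0f : -INFINITY;       // causal
            float f_swa = pos - kv_pos[i] >= n_swa ? -INFINITY : f;  // sliding window
            printf("cell %d: causal = %6.1f, swa = %6.1f\n", i, f, f_swa);
        }
    }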
+
+void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
+    if (cross_kq_mask) {
+        const int64_t n_enc    = cross_kq_mask->ne[0];
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+        float * data = (float *) cross_kq_mask->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_enc; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
+                        const llama_seq_id seq_id = ubatch->seq_id[j][s];
+                        if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
+                            f = 0.0f;
+                        }
+                    }
+                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
+                }
+            }
+
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
+                }
+            }
+        }
+    }
+}
+
+//
+// llm_graph_context
+//
+
+llm_graph_context::llm_graph_context(const llm_graph_params & params) :
+    arch             (params.arch),
+    hparams          (params.hparams),
+    cparams          (params.cparams),
+    ubatch           (params.ubatch),
+    n_embd           (hparams.n_embd),
+    n_layer          (hparams.n_layer),
+    n_rot            (hparams.n_rot),
+    n_ctx            (cparams.n_ctx),
+    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
+    n_head           (hparams.n_head()),
+    n_head_kv        (hparams.n_head_kv()),
+    n_embd_head_k    (hparams.n_embd_head_k),
+    n_embd_k_gqa     (hparams.n_embd_k_gqa()),
+    n_embd_head_v    (hparams.n_embd_head_v),
+    n_embd_v_gqa     (hparams.n_embd_v_gqa()),
+    n_expert         (hparams.n_expert),
+    n_expert_used    (hparams.n_expert_used),
+    freq_base        (cparams.rope_freq_base),
+    freq_scale       (cparams.rope_freq_scale),
+    ext_factor       (cparams.yarn_ext_factor),
+    attn_factor      (cparams.yarn_attn_factor),
+    beta_fast        (cparams.yarn_beta_fast),
+    beta_slow        (cparams.yarn_beta_slow),
+    norm_eps         (hparams.f_norm_eps),
+    norm_rms_eps     (hparams.f_norm_rms_eps),
+    n_tokens         (ubatch.n_tokens),
+    n_outputs        (params.n_outputs),
+    n_ctx_orig       (cparams.n_ctx_orig_yarn),
+    pooling_type     (cparams.pooling_type),
+    rope_type        (hparams.rope_type),
+    ctx0             (params.ctx),
+    sched            (params.sched),
+    backend_cpu      (params.backend_cpu),
+    cvec             (params.cvec),
+    loras            (params.loras),
+    memory           (params.memory),
+    cross            (params.cross),
+    cb_func          (params.cb),
+    res              (std::make_unique<llm_graph_result>()) {
+    }
+
+int64_t llm_graph_context::n_pos_per_token() const {
+    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+}
+
+void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
+    if (cb_func) {
+        cb_func(ubatch, cur, name, il);
+    }
+}
+
+ggml_tensor * llm_graph_context::build_cvec(
+         ggml_tensor * cur,
+                 int   il) const {
+    return cvec->apply_to(ctx0, cur, il);
+}
+
+ggml_tensor * llm_graph_context::build_lora_mm(
+          ggml_tensor * w,
+          ggml_tensor * cur) const {
+    ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+
+    for (const auto & lora : *loras) {
+        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
+        if (lw == nullptr) {
+            continue;
+        }
+
+        const float adapter_scale = lora.second;
+        const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
+
+        ggml_tensor * ab_cur = ggml_mul_mat(
+                ctx0, lw->b,
+                ggml_mul_mat(ctx0, lw->a, cur)
+                );
+
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+
+    return res;
+}
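Each adapter contributes scale * B(A x) on top of the base matmul, where the scale combines the user-provided adapter strength with alpha/rank (the same formula spelled out in build_lora_mm_id below). A 1x1 numeric sketch of that composition, with all values made up:

    // sketch: LoRA composition res = W*x + scale * B*(A*x)
    #include <cstdio>

    int main() {
        const float W = 2.0f, A = 0.5f, B = 4.0f; // 1x1 "matrices"
        const float alpha = 16.0f, rank = 8.0f;   // adapter hyper-parameters
        const float adapter_scale = 1.0f;         // user-provided strength

        const float x     = 3.0f;
        const float scale = adapter_scale*alpha/rank; // 2.0
        const float res   = W*x + scale*(B*(A*x));    // 6 + 2*6 = 18

        printf("res = %.1f\n", res);
    }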
+
+ggml_tensor * llm_graph_context::build_lora_mm_id(
+          ggml_tensor * w,   // ggml_tensor * as
+          ggml_tensor * cur, // ggml_tensor * b
+          ggml_tensor * ids) const {
+    ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (const auto & lora : *loras) {
+        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
+        if (lw == nullptr) {
+            continue;
+        }
+
+        const float alpha = lora.first->alpha;
+        const float rank  = (float) lw->b->ne[0];
+        const float scale = alpha ? lora.second * alpha / rank : lora.second;
+
+        ggml_tensor * ab_cur = ggml_mul_mat_id(
+                ctx0, lw->b,
+                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
+                ids
+                );
+
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+
+    return res;
+}
+
+ggml_tensor * llm_graph_context::build_norm(
+         ggml_tensor * cur,
+         ggml_tensor * mw,
+         ggml_tensor * mb,
+       llm_norm_type   type,
+                 int   il) const {
+    switch (type) {
+        case LLM_NORM:       cur = ggml_norm    (ctx0, cur, hparams.f_norm_eps);     break;
+        case LLM_NORM_RMS:   cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM_GROUP:
+            {
+                cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
+                cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
+                cur = ggml_reshape_2d(ctx0, cur, cur->ne[0],    cur->ne[2]);
+            } break;
+    }
+
+    if (mw || mb) {
+        cb(cur, "norm", il);
+    }
+
+    if (mw) {
+        cur = ggml_mul(ctx0, cur, mw);
+        if (mb) {
+            cb(cur, "norm_w", il);
+        }
+    }
+
+    if (mb) {
+        cur = ggml_add(ctx0, cur, mb);
+    }
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_ffn(
+         ggml_tensor * cur,
+         ggml_tensor * up,
+         ggml_tensor * up_b,
+         ggml_tensor * up_s,
+         ggml_tensor * gate,
+         ggml_tensor * gate_b,
+         ggml_tensor * gate_s,
+         ggml_tensor * down,
+         ggml_tensor * down_b,
+         ggml_tensor * down_s,
+         ggml_tensor * act_scales,
+     llm_ffn_op_type   type_op,
+   llm_ffn_gate_type   type_gate,
+                 int   il) const {
+    ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
+    cb(tmp, "ffn_up", il);
+
+    if (up_b) {
+        tmp = ggml_add(ctx0, tmp, up_b);
+        cb(tmp, "ffn_up_b", il);
+    }
+
+    if (up_s) {
+        tmp = ggml_mul(ctx0, tmp, up_s);
+        cb(tmp, "ffn_up_s", il);
+    }
+
+    if (gate) {
+        switch (type_gate) {
+            case LLM_FFN_SEQ:
+                {
+                    cur = build_lora_mm(gate, tmp);
+                    cb(cur, "ffn_gate", il);
+                } break;
+            case LLM_FFN_PAR:
+                {
+                    cur = build_lora_mm(gate, cur);
+                    cb(cur, "ffn_gate", il);
+                } break;
+        }
+
+        if (gate_b) {
+            cur = ggml_add(ctx0, cur, gate_b);
+            cb(cur, "ffn_gate_b", il);
+        }
+
+        if (gate_s) {
+            cur = ggml_mul(ctx0, cur, gate_s);
+            cb(cur, "ffn_gate_s", il);
+        }
+
+    } else {
+        cur = tmp;
+    }
+
+    switch (type_op) {
+        case LLM_FFN_SILU:
+            {
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_silu", il);
+            } break;
+        case LLM_FFN_GELU:
+            {
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_gelu", il);
+                if (act_scales != NULL) {
+                    cur = ggml_div(ctx0, cur, act_scales);
+                    cb(cur, "ffn_act", il);
+                }
+            } break;
+        case LLM_FFN_RELU:
+            {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_relu", il);
+            } break;
+        case LLM_FFN_RELU_SQR:
+            {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_relu", il);
+
+                cur = ggml_sqr(ctx0, cur);
+                cb(cur, "ffn_sqr(relu)", il);
+            } break;
+        case LLM_FFN_SWIGLU:
+            {
+                // Project to 4h. If using SwiGLU, double the output width (see https://arxiv.org/pdf/2002.05202.pdf)
+                int64_t split_point = cur->ne[0] / 2;
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_silu(ctx0, x0);
+                cb(cur, "ffn_silu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
+                cb(cur, "ffn_mul", il);
+            } break;
+    }
+
+    if (type_gate == LLM_FFN_PAR) {
+        cur = ggml_mul(ctx0, cur, tmp);
+        cb(cur, "ffn_gate_par", il);
+    }
+
+    if (down) {
+        cur = build_lora_mm(down, cur);
+    }
+
+    if (down_b) {
+        cb(cur, "ffn_down", il);
+        cur = ggml_add(ctx0, cur, down_b);
+    }
+
+    if (down_s) {
+        cur = ggml_mul(ctx0, cur, down_s);
+        cb(cur, "ffn_down_s", il);
+    }
+
+    return cur;
+}
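In the LLM_FFN_SWIGLU branch the up projection produces twice the FFN width per token and the two halves are recombined as silu(x0) * x1. A tiny sketch with an invented 2*n_ff activation vector:

    // sketch: the SwiGLU recombination silu(x0) * x1
    #include <cmath>
    #include <cstdio>

    static float silu(float x) { return x/(1.0f + expf(-x)); }

    int main() {
        const int n_ff = 3;
        const float up[2*n_ff] = { 1.0f, -2.0f, 0.5f,   3.0f, 1.0f, -1.0f }; // [ x0 | x1 ]

        for (int i = 0; i < n_ff; ++i) {
            printf("out[%d] = %f\n", i, silu(up[i])*up[n_ff + i]);
        }
    }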
+
+ggml_tensor * llm_graph_context::build_moe_ffn(
+         ggml_tensor * cur,
+         ggml_tensor * gate_inp,
+         ggml_tensor * up_exps,
+         ggml_tensor * gate_exps,
+         ggml_tensor * down_exps,
+         ggml_tensor * exp_probs_b,
+             int64_t   n_expert,
+             int64_t   n_expert_used,
+     llm_ffn_op_type   type_op,
+                bool   norm_w,
+                bool   scale_w,
+               float   w_scale,
+         llama_expert_gating_func_type gating_op,
+                 int   il) const {
+    int64_t n_embd = cur->ne[0];
+    int64_t n_tokens = cur->ne[1];
+
+    ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+    cb(logits, "ffn_moe_logits", il);
+
+    ggml_tensor * probs = nullptr;
+    switch (gating_op) {
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
+            {
+                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
+            } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
+            {
+                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+    cb(probs, "ffn_moe_probs", il);
+
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = ggml_get_rows(ctx0,
+            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights", il);
+
+    if (norm_w) {
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights_sum", il);
+
+        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+    }
+    if (scale_w) {
+        weights = ggml_scale(ctx0, weights, w_scale);
+        cb(weights, "ffn_moe_weights_scaled", il);
+    }
+
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(gate, "ffn_moe_gate", il);
+
+    switch (type_op) {
+        case LLM_FFN_SILU:
+            {
+                gate = ggml_silu(ctx0, gate);
+                cb(gate, "ffn_moe_silu", il);
+            } break;
+        case LLM_FFN_GELU:
+            {
+                gate = ggml_gelu(ctx0, gate);
+                cb(gate, "ffn_moe_gelu", il);
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+
+    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
+    cb(par, "ffn_moe_gate_par", il);
+
+    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx0, experts, weights);
+
+    // aggregate experts
+    ggml_tensor * moe_out = nullptr;
+    for (int i = 0; i < n_expert_used; ++i) {
+        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]);
+
+        if (i == 0) {
+            moe_out = cur_expert;
+        } else {
+            moe_out = ggml_add(ctx0, moe_out, cur_expert);
+        }
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx0, moe_out);
+    }
+
+    return moe_out;
+}
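The routing above is: gating probabilities, top-k expert selection, optional renormalization of the selected weights (the norm_w path), then a weighted sum of the selected experts' FFN outputs. A toy single-token sketch of the routing math (four invented logits, top-2 selection):

    // sketch: MoE routing for a single token
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> logits = { 1.0f, 3.0f, 2.0f, 0.5f }; // n_expert = 4
        const int n_expert_used = 2;

        // softmax gating
        std::vector<float> probs(logits.size());
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) {
            sum += probs[i] = expf(logits[i]);
        }
        for (auto & p : probs) {
            p /= sum;
        }

        // top-k expert ids
        std::vector<int> ids(logits.size());
        for (int i = 0; i < (int) ids.size(); ++i) {
            ids[i] = i;
        }
        std::partial_sort(ids.begin(), ids.begin() + n_expert_used, ids.end(),
                [&](int a, int b) { return probs[a] > probs[b]; });

        // renormalize the selected weights (norm_w == true)
        float wsum = 0.0f;
        for (int k = 0; k < n_expert_used; ++k) {
            wsum += probs[ids[k]];
        }
        for (int k = 0; k < n_expert_used; ++k) {
            printf("expert %d, weight %.3f\n", ids[k], probs[ids[k]]/wsum);
        }
        // the final output is sum_k weight_k * expert_{ids[k]}(x)
    }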
+
+// input embeddings with optional lora
+ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
+    const int64_t n_embd = hparams.n_embd;
+
+    auto inp = std::make_unique<llm_graph_input_embd>();
+
+    ggml_tensor * cur = nullptr;
+
+    if (ubatch.token) {
+        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+        //cb(inp->tokens, "inp_tokens", -1);
+        ggml_set_input(inp->tokens);
+
+        cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
+
+        // apply lora for embedding tokens if needed
+        for (const auto & lora : *loras) {
+            llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
+            if (lw == nullptr) {
+                continue;
+            }
+
+            const float adapter_scale = lora.second;
+            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
+
+            ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
+                        ctx0, lw->b, // non-transposed lora_b
+                        ggml_get_rows(ctx0, lw->a, inp->tokens)
+                        ), scale);
+
+            cur = ggml_add(ctx0, cur, inpL_delta);
+        }
+    } else {
+        inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
+        ggml_set_input(inp->embd);
+
+        cur = inp->embd;
+    }
+
+    // For Granite architecture
+    if (hparams.f_embedding_scale != 0.0f) {
+        cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
+    }
+
+    cb(cur, "inp_embd", -1);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_pos() const {
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+
+    auto & cur = inp->pos;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_out_ids() const {
+    auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
+
+    auto & cur = inp->out_ids;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_mean() const {
+    auto inp = std::make_unique<llm_graph_input_mean>(cparams);
+
+    auto & cur = inp->mean;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_cls() const {
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+
+    auto & cur = inp->cls;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_s_copy() const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_s_mask() const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    auto & cur = inp->s_mask;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
+    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
+
+    auto & cur = inp->cross_embd;
+
+    // if we have the output embeddings from the encoder, use them directly
+    // TODO: needs more work to be correct, for now just use the tensor shape
+    //if (cross->t_embd) {
+    //    cur = ggml_view_tensor(ctx0, cross->t_embd);
+
+    //    return cur;
+    //}
+
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+    const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
+    auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
+
+    auto & cur = inp->pos_bucket;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    auto & cur = inp->pos_bucket;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
+    ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
+    cb(pos_bucket_1d, "pos_bucket_1d", -1);
+
+    ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
+
+    pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
+    pos_bias = ggml_permute   (ctx0, pos_bias, 2, 0, 1, 3);
+    pos_bias = ggml_cont      (ctx0, pos_bias);
+
+    cb(pos_bias, "pos_bias", -1);
+
+    return pos_bias;
+}
+
+ggml_tensor * llm_graph_context::build_attn_mha(
+         ggml_cgraph * gf,
+         ggml_tensor * q,
+         ggml_tensor * k,
+         ggml_tensor * v,
+         ggml_tensor * kq_b,
+         ggml_tensor * kq_mask,
+             bool      v_trans,
+             float     kq_scale) const {
+  //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+  //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+  //const int64_t n_head    = hparams.n_head(il);
+  //const int64_t n_head_kv = hparams.n_head_kv(il);
+
+  //const auto & n_embd_head_k = hparams.n_embd_head_k;
+  //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+    const auto n_kv     = k->ne[1];
+
+    ggml_tensor * cur;
+
+    // TODO: replace hardcoded padding with ggml-provided padding
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
+
+        if (v_trans) {
+            v = ggml_transpose(ctx0, v);
+        }
+
+        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+    } else {
+        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+
+        // note: this op tends to require high floating point range
+        //       while for some models F16 is enough, for others it is not, so we default to F32 here
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+        if (arch == LLM_ARCH_GROK) {
+            // need to do the following:
+            // multiply by attn_output_multiplier of 0.08838834764831845
+            // and then:
+            // kq = 30 * tanh(kq / 30)
+            // before the softmax below
+
+            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+            kq = ggml_scale(ctx0, kq, 30);
+        }
+
+        if (hparams.attn_soft_cap) {
+            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = ggml_tanh (ctx0, kq);
+            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+        }
+
+        if (kq_b) {
+            kq = ggml_add(ctx0, kq, kq_b);
+        }
+
+        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+
+        if (!v_trans) {
+            // note: avoid this branch
+            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+        }
+
+        ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+
+        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+
+        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+
+        if (!cparams.offload_kqv) {
+            // all nodes between the KV store and the attention output are run on the CPU
+            ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
+        }
+    }
+
+    ggml_build_forward_expand(gf, cur);
+
+    return cur;
+}
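The fallback path above is plain scaled dot-product attention: kq = K^T q with optional soft-capping and bias, a softmax that folds in the mask, then a matmul with V. A naive single-head sketch with made-up 2-dimensional K/V/q and the last cell masked out:

    // sketch: masked scaled dot-product attention for one query
    #include <cmath>
    #include <cstdio>

    int main() {
        const int n_kv = 3, d = 2;
        const float K[n_kv][d] = { {1, 0}, {0, 1}, {1, 1} };
        const float V[n_kv][d] = { {1, 0}, {0, 1}, {2, 2} };
        const float q[d]       = { 1, 1 };
        const float mask[n_kv] = { 0.0f, 0.0f, -INFINITY }; // last cell masked out
        const float scale      = 1.0f/sqrtf((float) d);

        float w[n_kv], sum = 0.0f;
        for (int i = 0; i < n_kv; ++i) {
            float dot = 0.0f;
            for (int j = 0; j < d; ++j) {
                dot += K[i][j]*q[j];
            }
            sum += w[i] = expf(dot*scale + mask[i]); // softmax numerator, 0 for masked cells
        }

        float out[d] = {};
        for (int i = 0; i < n_kv; ++i) {
            for (int j = 0; j < d; ++j) {
                out[j] += (w[i]/sum)*V[i][j];
            }
        }
        printf("out = (%f, %f)\n", out[0], out[1]);
    }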
+
+llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
+    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
+
+    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
+    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    //cb(inp_kq_mask, "KQ_mask", -1);
+    ggml_set_input(inp->kq_mask);
+
+    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+
+    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_no_cache * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+            float     kq_scale,
+            int       il) const {
+    GGML_UNUSED(n_tokens);
+
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);
+
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);
+
+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
+                bool   causal,
+                bool   swa) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    inp->self_kq_mask = causal
+        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    //cb(inp->self_kq_mask, "KQ_mask", -1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+
+    if (swa) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        inp->self_kq_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto & n_ctx = cparams.n_ctx;
+
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    const auto n_tokens = q_cur->ne[2];
+
+    const bool v_trans = !cparams.flash_attn;
+
+    // store to KV cache
+    {
+        GGML_ASSERT(!kv_self->recurrent);
+
+        const auto kv_head = kv_self->head;
+
+        GGML_ASSERT(kv_self->size == n_ctx);
+
+        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
+        //cb(k_cache_view, "k_cache_view", il);
+
+        // note: storing RoPE-ed version of K in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+
+        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+
+        ggml_tensor * v_cache_view = nullptr;
+
+        if (!v_trans) {
+            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
+        } else {
+            // note: the V cache is transposed when not using flash attention
+            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
+                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
+                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+
+            v_cur = ggml_transpose(ctx0, v_cur);
+        }
+        //cb(v_cache_view, "v_cache_view", il);
+
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+    }
+
+    // TODO: improve
+    bool is_sliding = false;
+
+    switch (arch) {
+        case LLM_ARCH_COHERE2:
+            {
+                const int32_t sliding_window_pattern = 4;
+                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            } break;
+        case LLM_ARCH_GEMMA2:
+            {
+                const int32_t sliding_window_pattern = 2;
+                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                const int32_t sliding_window_pattern = 6;
+                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            } break;
+        case LLM_ARCH_PHI3:
+            {
+                is_sliding = hparams.n_swa > 0;
+            } break;
+        default:
+            {
+                is_sliding = false;
+            }
+    }
+
+    const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    const auto n_kv = kv_self->n;
+
+    const int64_t n_head_kv = hparams.n_head_kv(il);
+
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+    const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);
+
+    ggml_tensor * k =
+        ggml_view_3d(ctx0, kv_self->k_l[il],
+                n_embd_head_k, n_kv, n_head_kv,
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
+                0);
+    //cb(k, "k", il);
+
+    ggml_tensor * v = !v_trans ?
+        ggml_view_3d(ctx0, kv_self->v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                0) :
+        ggml_view_3d(ctx0, kv_self->v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv_self->v_l[il])*n_ctx,
+                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
+                0);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
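
The hard-coded patterns above interleave sliding-window and global-attention layers: within each repeating period, every layer except the last uses the sliding-window mask. The same predicate as a standalone sketch, with the pattern lengths taken from the cases above:

```cpp
#include <cstdint>

// Sketch of the predicate used above: with a repeating pattern of length
// `pattern`, layers 0 .. pattern-2 inside each period use the sliding-window
// mask and the last layer of the period attends globally.
static bool layer_is_sliding(int il, int32_t pattern) {
    return il % pattern < (pattern - 1);
}

// e.g. Gemma 3 (pattern == 6): layers 0..4 sliding, layer 5 global, repeating;
// Gemma 2 (pattern == 2) alternates sliding/global; Cohere2 uses pattern == 4.
```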
+
+llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
+    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
+
+    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
+
+    inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    ggml_set_input(inp->cross_kq_mask);
+
+    inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
+
+    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_cross * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto & kq_mask = inp->get_kq_mask_cross();
+
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);
+
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);
+
+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_copy_mask_state(
+         ggml_cgraph * gf,
+         ggml_tensor * s,
+         ggml_tensor * state_copy,
+         ggml_tensor * state_mask,
+             int32_t   n_state,
+             int32_t   n_seqs) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    const auto n_kv    = kv_self->n;
+    const auto kv_head = kv_self->head;
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
+
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // this shrinks the tensor's ne[1] to n_kv
+    states = ggml_get_rows(ctx0, states, state_copy);
+
+    // clear states of sequences which are starting at the beginning of this batch
+    // FIXME: zero-out NANs?
+    states = ggml_mul(ctx0, states, state_mask);
+
+    // copy states which won't be changed further (between n_seqs and n_kv)
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
+            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+
+    // the part of the states that will be used and modified
+    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+}
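
Conceptually, the gather (ggml_get_rows with state_copy) followed by the broadcast multiply with the [1, n_kv] state_mask amounts to the following host-side loop, shown here as a plain-array sketch rather than library code:

```cpp
#include <vector>

// Sketch of the gather + mask that the graph ops express on tensors.
// states: n_kv rows of n_state floats; copy[i] names the source row for row i;
// mask[i] is 0.0f for sequences that restart in this batch, 1.0f otherwise.
static std::vector<float> copy_mask_state(
        const std::vector<float> & states,
        const std::vector<int>   & copy,
        const std::vector<float> & mask,
        int n_state) {
    const int n_kv = (int) copy.size();
    std::vector<float> out((size_t) n_kv * n_state);
    for (int i = 0; i < n_kv; ++i) {
        for (int j = 0; j < n_state; ++j) {
            out[(size_t) i*n_state + j] = states[(size_t) copy[i]*n_state + j] * mask[i];
        }
    }
    return out;
}
```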
+
+ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
+         ggml_cgraph * gf,
+         ggml_tensor * state_copy,
+         ggml_tensor * state_mask,
+  const llama_ubatch & ubatch,
+                 int   il) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    const auto token_shift_count = hparams.token_shift_count;
+
+    const int64_t n_seqs  = ubatch.n_seqs;
+
+    ggml_tensor * token_shift_all = kv_self->k_l[il];
+
+    ggml_tensor * token_shift = build_copy_mask_state(
+            gf, token_shift_all, state_copy, state_mask,
+            hparams.n_embd_k_s(), n_seqs);
+
+    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
+
+    return token_shift;
+}
+
+ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
+         ggml_tensor * token_shift,
+  const llama_ubatch & ubatch,
+                 int   il) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    const auto token_shift_count = hparams.token_shift_count;
+    const auto n_embd = hparams.n_embd;
+
+    const int64_t n_seqs = ubatch.n_seqs;
+
+    const auto kv_head = kv_self->head;
+
+    return ggml_cpy(
+        ctx0,
+        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
+        ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
+    );
+}
+
+void llm_graph_context::build_pooling(
+        ggml_cgraph * gf,
+        ggml_tensor * cls,
+        ggml_tensor * cls_b,
+        ggml_tensor * cls_out,
+        ggml_tensor * cls_out_b) const {
+    if (!cparams.embeddings) {
+        return;
+    }
+
+    ggml_tensor * inp = res->t_embd;
+
+    //// find result_norm tensor for input
+    //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+    //    inp = ggml_graph_node(gf, i);
+    //    if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+    //        break;
+    //    }
+
+    //    inp = nullptr;
+    //}
+
+    GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
+
+    ggml_tensor * cur;
+
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_NONE:
+            {
+                cur = inp;
+            } break;
+        case LLAMA_POOLING_TYPE_MEAN:
+            {
+                ggml_tensor * inp_mean = build_inp_mean();
+                cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
+            } break;
+        case LLAMA_POOLING_TYPE_CLS:
+        case LLAMA_POOLING_TYPE_LAST:
+            {
+                ggml_tensor * inp_cls = build_inp_cls();
+                cur = ggml_get_rows(ctx0, inp, inp_cls);
+            } break;
+        case LLAMA_POOLING_TYPE_RANK:
+            {
+                ggml_tensor * inp_cls = build_inp_cls();
+                inp = ggml_get_rows(ctx0, inp, inp_cls);
+
+                // classification head
+                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                GGML_ASSERT(cls   != nullptr);
+                GGML_ASSERT(cls_b != nullptr);
+
+                cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
+                cur = ggml_tanh(ctx0, cur);
+
+                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                if (cls_out) {
+                    GGML_ASSERT(cls_out_b != nullptr);
+
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("unknown pooling type");
+            }
+    }
+
+    cb(cur, "result_embd_pooled", -1);
+    res->t_embd_pooled = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
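
For the MEAN case, inp_mean is conceptually an [n_batch, n_batch] weight matrix whose rows average the token columns belonging to each sequence, so a single matmul yields all per-sequence means at once. A host-side sketch of the equivalent computation on hypothetical plain arrays:

```cpp
#include <vector>

// Sketch: mean-pool embeddings per sequence. embd holds n_tokens rows of
// n_embd floats; seq_id[t] assigns token t to a sequence; returns n_seq
// pooled rows. This is what the ggml_mul_mat with inp_mean computes.
static std::vector<float> mean_pool(
        const std::vector<float> & embd,
        const std::vector<int>   & seq_id,
        int n_embd, int n_seq) {
    const int n_tokens = (int) seq_id.size();
    std::vector<float> out((size_t) n_seq * n_embd, 0.0f);
    std::vector<int>   cnt(n_seq, 0);
    for (int t = 0; t < n_tokens; ++t) {
        cnt[seq_id[t]]++;
        for (int j = 0; j < n_embd; ++j) {
            out[(size_t) seq_id[t]*n_embd + j] += embd[(size_t) t*n_embd + j];
        }
    }
    for (int s = 0; s < n_seq; ++s) {
        for (int j = 0; j < n_embd; ++j) {
            if (cnt[s] > 0) out[(size_t) s*n_embd + j] /= (float) cnt[s];
        }
    }
    return out;
}
```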
+

+ 576 - 0
src/llama-graph.h

@@ -0,0 +1,576 @@
+#pragma once
+
+#include "llama-arch.h"
+#include "llama-hparams.h"
+#include "llama-adapter.h"
+
+#include <cstdint>
+#include <vector>
+#include <memory>
+#include <set>
+#include <functional>
+
+struct ggml_cgraph;
+struct ggml_context;
+struct ggml_tensor;
+
+struct llama_ubatch;
+struct llama_cparams;
+
+class llama_memory_i;
+class llama_kv_cache_unified;
+
+// certain models (typically multi-modal) can produce different types of graphs
+enum llm_graph_type {
+    LLM_GRAPH_TYPE_DEFAULT,
+    LLM_GRAPH_TYPE_ENCODER,
+    LLM_GRAPH_TYPE_DECODER,
+};
+
+enum llm_ffn_op_type {
+    LLM_FFN_SILU,
+    LLM_FFN_GELU,
+    LLM_FFN_RELU,
+    LLM_FFN_RELU_SQR,
+    LLM_FFN_SWIGLU,
+};
+
+enum llm_ffn_gate_type {
+    LLM_FFN_SEQ,
+    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
+};
+
+enum llm_norm_type {
+    LLM_NORM,
+    LLM_NORM_RMS,
+    LLM_NORM_GROUP,
+};
+
+// TODO: tmp - need something better to pass the data from the encoder to the decoder
+struct llama_cross {
+    // the output embeddings from the encoder as a ggml tensor
+    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
+    //ggml_tensor * t_embd = nullptr;
+
+    int64_t n_embd = 0;
+    int64_t n_enc  = 0;
+
+    // embeddings data copied to host memory (tmp)
+    std::vector<float> v_embd;
+
+    // needed to construct the cross-attention mask in the decoder
+    std::vector<std::set<llama_seq_id>> seq_ids_enc;
+};
+
+//
+// llm_graph_input
+//
+
+class llm_graph_input_i {
+public:
+    virtual ~llm_graph_input_i() = default;
+
+    virtual void set_input(const llama_ubatch * ubatch) = 0;
+};
+
+using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
+
+
+class llm_graph_input_embd : public llm_graph_input_i {
+public:
+    llm_graph_input_embd()          = default;
+    virtual ~llm_graph_input_embd() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * tokens = nullptr; // I32 [n_batch]
+    ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
+};
+
+class llm_graph_input_pos : public llm_graph_input_i {
+public:
+    llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+    virtual ~llm_graph_input_pos() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos = nullptr; // I32 [n_batch]
+
+    const int64_t n_pos_per_token = 1;
+};
+
+class llm_graph_input_pos_bucket : public llm_graph_input_i {
+public:
+    llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
+    virtual ~llm_graph_input_pos_bucket() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
+
+    const llama_hparams & hparams;
+};
+
+class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
+public:
+    llm_graph_input_pos_bucket_kv(
+            const llama_hparams & hparams,
+            const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
+    virtual ~llm_graph_input_pos_bucket_kv() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_out_ids : public llm_graph_input_i {
+public:
+    llm_graph_input_out_ids(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
+    virtual ~llm_graph_input_out_ids() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * out_ids = nullptr; // I32 [n_outputs]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const int32_t n_outputs;
+};
+
+class llm_graph_input_mean : public llm_graph_input_i {
+public:
+    llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
+    virtual ~llm_graph_input_mean() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * mean = nullptr; // F32 [n_batch, n_batch]
+
+    const llama_cparams & cparams;
+};
+
+class llm_graph_input_cls : public llm_graph_input_i {
+public:
+    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    virtual ~llm_graph_input_cls() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cls = nullptr; // I32 [n_batch]
+
+    const llama_cparams & cparams;
+};
+
+class llm_graph_input_s_copy : public llm_graph_input_i {
+public:
+    llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_s_copy() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy = nullptr; // I32 [kv_size]
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_s_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_s_mask() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_mask = nullptr; // F32 [1, n_kv]
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_cross_embd : public llm_graph_input_i {
+public:
+    llm_graph_input_cross_embd(
+            const llama_cross * cross) : cross(cross) {}
+    virtual ~llm_graph_input_cross_embd() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cross_embd = nullptr; // F32 [n_embd, n_outputs_enc]
+
+    const llama_cross * cross;
+};
+
+class llm_graph_input_attn_no_cache : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
+        hparams(hparams),
+        cparams(cparams) {
+    }
+    ~llm_graph_input_attn_no_cache() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+
+    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch]
+    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+};
+
+class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
+
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_cross : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
+    ~llm_graph_input_attn_cross() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
+
+    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; //     [n_outputs_enc, n_batch]
+
+    const llama_cross * cross = nullptr;
+};
+
+//
+// llm_graph_result
+//
+
+// these objects deliver the result from the graph build process back to the llama_context
+// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
+//   specific data, by calling the set_inputs() method
+// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
+//   these are used by the llama_context to extract the relevant data, based on the compute parameters
+
+class llm_graph_result_i {
+public:
+    virtual ~llm_graph_result_i() = default;
+
+    virtual ggml_tensor * get_logits()      = 0;
+    virtual ggml_tensor * get_embd()        = 0;
+    virtual ggml_tensor * get_embd_pooled() = 0;
+
+    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
+};
+
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
+
+
+class llm_graph_result : public llm_graph_result_i {
+public:
+    virtual ~llm_graph_result() = default;
+
+    ggml_tensor * get_logits()      override { return t_logits; }
+    ggml_tensor * get_embd()        override { return t_embd; }
+    ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
+
+    void set_inputs(const llama_ubatch * ubatch) override {
+        for (auto & input : inputs) {
+            input->set_input(ubatch);
+        }
+    }
+
+    llm_graph_input_i * add_input(llm_graph_input_ptr input) {
+        inputs.emplace_back(std::move(input));
+        return inputs.back().get();
+    }
+
+    // important graph nodes
+    ggml_tensor * t_logits      = nullptr;
+    ggml_tensor * t_embd        = nullptr;
+    ggml_tensor * t_embd_pooled = nullptr;
+
+    std::vector<llm_graph_input_ptr> inputs;
+};
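
Extending the graph with a new input type follows the same recipe as the built-in ones: derive from llm_graph_input_i, create and mark the tensor during graph build, and register the object so set_inputs() can later populate it. A hedged sketch with a hypothetical input class (the name and field are illustrative, not part of the patch):

```cpp
#include "llama-graph.h"

// Hypothetical example input (not part of the patch): a per-token F32 scale.
class llm_graph_input_scale : public llm_graph_input_i {
public:
    void set_input(const llama_ubatch * ubatch) override {
        // a real implementation would copy per-token data from the ubatch
        // into the tensor's backend buffer, e.g. with ggml_backend_tensor_set
        (void) ubatch;
    }

    ggml_tensor * scale = nullptr; // F32 [n_tokens]
};

// during graph build, inside an llm_graph_context method:
//   auto inp = std::make_unique<llm_graph_input_scale>();
//   inp->scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
//   ggml_set_input(inp->scale);
//   res->add_input(std::move(inp));
```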
+
+//
+// llm_graph_context
+//
+
+// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
+using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
+
+struct llm_graph_params {
+    ggml_context * ctx;
+
+    const llm_arch arch;
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+    const llama_ubatch  & ubatch;
+
+    ggml_backend_sched * sched;
+    ggml_backend * backend_cpu;
+
+    const llama_adapter_cvec  * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_i      * memory;
+    const llama_cross         * cross;
+
+    int32_t n_outputs;
+
+    const llm_graph_cb & cb;
+};
+
+struct llm_graph_context {
+    const llm_arch arch;
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+    const llama_ubatch  & ubatch;
+
+    const int64_t n_embd;
+    const int64_t n_layer;
+    const int64_t n_rot;
+    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
+    const int64_t n_ctx_per_seq;
+    const int64_t n_head;
+    const int64_t n_head_kv;
+    const int64_t n_embd_head_k;
+    const int64_t n_embd_k_gqa;
+    const int64_t n_embd_head_v;
+    const int64_t n_embd_v_gqa;
+    const int64_t n_expert;
+    const int64_t n_expert_used;
+
+    const float freq_base;
+    const float freq_scale;
+    const float ext_factor;
+    const float attn_factor;
+    const float beta_fast;
+    const float beta_slow;
+    const float norm_eps;
+    const float norm_rms_eps;
+
+    const int32_t n_tokens;
+    const int32_t n_outputs;
+    const int32_t n_ctx_orig; // yarn
+
+    const enum llama_pooling_type pooling_type;
+    const enum llama_rope_type    rope_type;
+
+    ggml_context * ctx0 = nullptr;
+
+    ggml_backend_sched * sched;
+
+    ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+
+    const llama_adapter_cvec  * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_i      * memory;
+    const llama_cross         * cross;
+
+    const llm_graph_cb & cb_func;
+
+    std::unique_ptr<llm_graph_result> res;
+
+    llm_graph_context(const llm_graph_params & params);
+
+    int64_t n_pos_per_token() const;
+
+    void cb(ggml_tensor * cur, const char * name, int il) const;
+
+    //
+    // common
+    //
+
+    ggml_tensor * build_cvec(
+             ggml_tensor * cur,
+                     int   il) const;
+
+    // do mat_mul, optionally applying LoRA
+    ggml_tensor * build_lora_mm(
+              ggml_tensor * w,
+              ggml_tensor * cur) const;
+
+    // do mat_mul_id, optionally applying LoRA
+    ggml_tensor * build_lora_mm_id(
+              ggml_tensor * w,   // ggml_tensor * as
+              ggml_tensor * cur, // ggml_tensor * b
+              ggml_tensor * ids) const;
+
+    ggml_tensor * build_norm(
+             ggml_tensor * cur,
+             ggml_tensor * mw,
+             ggml_tensor * mb,
+           llm_norm_type   type,
+                     int   il) const;
+
+    ggml_tensor * build_ffn(
+             ggml_tensor * cur,
+             ggml_tensor * up,
+             ggml_tensor * up_b,
+             ggml_tensor * up_s,
+             ggml_tensor * gate,
+             ggml_tensor * gate_b,
+             ggml_tensor * gate_s,
+             ggml_tensor * down,
+             ggml_tensor * down_b,
+             ggml_tensor * down_s,
+             ggml_tensor * act_scales,
+         llm_ffn_op_type   type_op,
+       llm_ffn_gate_type   type_gate,
+                     int   il) const;
+
+    ggml_tensor * build_moe_ffn(
+             ggml_tensor * cur,
+             ggml_tensor * gate_inp,
+             ggml_tensor * up_exps,
+             ggml_tensor * gate_exps,
+             ggml_tensor * down_exps,
+             ggml_tensor * exp_probs_b,
+                 int64_t   n_expert,
+                 int64_t   n_expert_used,
+         llm_ffn_op_type   type_op,
+                    bool   norm_w,
+                    bool   scale_w,
+                   float   w_scale,
+            llama_expert_gating_func_type gating_op,
+                     int   il) const;
+
+    //
+    // inputs
+    //
+
+    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
+    ggml_tensor * build_inp_pos() const;
+    ggml_tensor * build_inp_out_ids() const;
+    ggml_tensor * build_inp_mean() const;
+    ggml_tensor * build_inp_cls() const;
+    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_mask() const;
+
+    ggml_tensor * build_inp_cross_embd() const;
+    ggml_tensor * build_inp_pos_bucket_enc() const;
+    ggml_tensor * build_inp_pos_bucket_dec() const;
+    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
+
+    //
+    // attention
+    //
+
+    ggml_tensor * build_attn_mha(
+             ggml_cgraph * gf,
+             ggml_tensor * q,
+             ggml_tensor * k,
+             ggml_tensor * v,
+             ggml_tensor * kq_b,
+             ggml_tensor * kq_mask,
+                    bool   v_trans,
+                   float   kq_scale) const;
+
+    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_no_cache * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
+                  float   kq_scale,
+                    int   il) const;
+
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
+            bool causal,
+            bool swa) const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
+                  float   kq_scale,
+                    int   il) const;
+
+    llm_graph_input_attn_cross * build_attn_inp_cross() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_cross * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
+                  float   kq_scale,
+                    int   il) const;
+
+    //
+    // recurrent
+    //
+
+    ggml_tensor * build_copy_mask_state(
+             ggml_cgraph * gf,
+             ggml_tensor * s,
+             ggml_tensor * state_copy,
+             ggml_tensor * state_mask,
+                 int32_t   n_state,
+                 int32_t   n_seqs) const;
+
+    ggml_tensor * build_rwkv_token_shift_load(
+             ggml_cgraph * gf,
+             ggml_tensor * state_copy,
+             ggml_tensor * state_mask,
+      const llama_ubatch & ubatch,
+                     int   il) const;
+
+    ggml_tensor * build_rwkv_token_shift_store(
+             ggml_tensor * token_shift,
+      const llama_ubatch & ubatch,
+                     int   il) const;
+
+    //
+    // pooling
+    //
+
+    void build_pooling(
+            ggml_cgraph * gf,
+            ggml_tensor * cls,
+            ggml_tensor * cls_b,
+            ggml_tensor * cls_out,
+            ggml_tensor * cls_out_b) const;
+};

+ 15 - 0
src/llama-io.cpp

@@ -0,0 +1,15 @@
+#include "llama-io.h"
+
+void llama_io_write_i::write_string(const std::string & str) {
+    uint32_t str_size = str.size();
+
+    write(&str_size,  sizeof(str_size));
+    write(str.data(), str_size);
+}
+
+void llama_io_read_i::read_string(std::string & str) {
+    uint32_t str_size;
+    read_to(&str_size, sizeof(str_size));
+
+    str.assign((const char *) read(str_size), str_size);
+}
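
A string is therefore serialized as a 32-bit length prefix followed by the raw bytes. A sketch of the resulting layout (byte order is whatever the host uses, since write() dumps raw memory):

```cpp
// layout produced by write_string("hi"):
//
//   offset 0: 02 00 00 00   uint32_t length (raw host byte order; little-endian shown)
//   offset 4: 68 69         'h' 'i' -- no terminating NUL is written
```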

+ 35 - 0
src/llama-io.h

@@ -0,0 +1,35 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+struct ggml_tensor;
+
+class llama_io_write_i {
+public:
+    llama_io_write_i() = default;
+    virtual ~llama_io_write_i() = default;
+
+    virtual void write(const void * src, size_t size) = 0;
+    virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
+
+    // bytes written so far
+    virtual size_t n_bytes() = 0;
+
+    void write_string(const std::string & str);
+};
+
+class llama_io_read_i {
+public:
+    llama_io_read_i() = default;
+    virtual ~llama_io_read_i() = default;
+
+    virtual const uint8_t * read(size_t size) = 0;
+    virtual void read_to(void * dst, size_t size) = 0;
+
+    // bytes read so far
+    virtual size_t n_bytes() = 0;
+
+    void read_string(std::string & str);
+};
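
A minimal buffer-backed writer satisfying this interface could look as follows. This is a sketch for illustration; the real implementations live elsewhere in the patch (llama-context.cpp) and also serialize tensor data from backend buffers:

```cpp
#include "llama-io.h"

#include <cstdint>
#include <vector>

// Sketch: buffers all writes into a std::vector<uint8_t>.
struct llama_io_write_buffer_sketch : public llama_io_write_i {
    std::vector<uint8_t> data;

    void write(const void * src, size_t size) override {
        const uint8_t * p = static_cast<const uint8_t *>(src);
        data.insert(data.end(), p, p + size);
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        // placeholder: a real implementation copies `size` bytes of tensor
        // data starting at `offset` out of the tensor's backend buffer
        (void) tensor; (void) offset;
        data.insert(data.end(), size, uint8_t(0));
    }

    size_t n_bytes() override { return data.size(); }
};
```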

These changes were suppressed because the diff is too large
+ 989 - 281
src/llama-kv-cache.cpp


+ 178 - 110
src/llama-kv-cache.h

@@ -1,12 +1,29 @@
 #pragma once
 
 #include "llama.h"
+#include "llama-io.h"
+#include "llama-memory.h"
 
 #include "ggml-cpp.h"
 
+#include <functional>
 #include <set>
 #include <vector>
-#include <algorithm>
+
+struct llama_cparams;
+struct llama_hparams;
+struct llama_ubatch;
+
+struct llama_kv_cache : public llama_memory_i {
+    using llama_memory_i::llama_memory_i;
+
+    virtual int32_t  get_n_tokens()   const = 0;
+    virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too specific to the unified cache
+
+    virtual bool get_can_shift() const = 0;
+
+    bool get_can_edit() const override { return get_can_shift(); }
+};
 
 struct llama_kv_cell {
     llama_pos pos   = -1;
@@ -29,55 +46,6 @@ struct llama_kv_cell {
     }
 };
 
-// ring-buffer of cached KV data
-struct llama_kv_cache {
-    bool has_shift = false;
-    bool do_defrag = false;
-    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-    bool v_trans   = true;  // the value tensor is transposed
-    bool can_shift = false;
-
-    // Note: The value of head isn't only used to optimize searching
-    // for a free KV slot. llama_decode_impl also uses it, so it
-    // cannot be freely changed after a slot has been allocated.
-    uint32_t head = 0;
-    uint32_t size = 0;
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
-    ggml_type type_k = GGML_TYPE_F16;
-    ggml_type type_v = GGML_TYPE_F16;
-
-    std::vector<llama_kv_cell> cells;
-
-    std::vector<struct ggml_tensor *> k_l; // per layer
-    std::vector<struct ggml_tensor *> v_l;
-
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    size_t total_size() const {
-        size_t size = 0;
-        for (const auto & buf : bufs) {
-            size += ggml_backend_buffer_get_size(buf.get());
-        }
-
-        return size;
-    }
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos max_pos() const {
-        llama_pos max_pos = -1;
-        for (const auto & cell : cells) {
-            max_pos = std::max(max_pos, cell.pos);
-        }
-
-        return max_pos;
-    }
-};
-
 // a structure holds information about the slot found in llama_kv_cache_find_slot
 struct llama_kv_cache_slot_info {
     std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
@@ -89,82 +57,131 @@ struct llama_kv_cache_slot_info {
     operator bool() const { return found; }
 };
 
-// TODO: maybe not needed
-uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
-
-bool llama_kv_cache_init(
-        struct llama_kv_cache & cache,
-            const llama_model & model,
+// ring-buffer of cached KV data
+// TODO: pimpl
+// TODO: add notion of max sequences
+class llama_kv_cache_unified : public llama_kv_cache {
+public:
+    // can be used to query data from the model if needed
+    struct callbacks {
+        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
+    };
+
+    llama_kv_cache_unified(
+            const llama_hparams & hparams,
+            callbacks             cbs);
+
+    virtual ~llama_kv_cache_unified() = default;
+
+    // TODO: become constructor
+    bool init(
+            const llama_model & model,   // TODO: do not reference the model
           const llama_cparams & cparams,
                     ggml_type   type_k,
                     ggml_type   type_v,
                      uint32_t   kv_size,
                          bool   offload);
 
-// find an empty slot of size "n_tokens" in the cache
-// updates the cache head
-// returns a structure holding information about the slot found
-// Note: On success, it's important that cache.head points
-// to the first cell of the slot.
-struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
-           struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch);
+    int32_t  get_n_tokens()   const override;
+    uint32_t get_used_cells() const override;
 
-// find how many cells are currently in use
-uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
+    size_t total_size() const;
 
-void llama_kv_cache_clear(struct llama_kv_cache & cache);
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos pos_max() const;
 
-bool llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1);
+    void clear() override;
+    void defrag() override;
 
-void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1);
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
-void llama_kv_cache_seq_keep(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id);
+    llama_pos seq_pos_max(llama_seq_id seq_id) override;
 
-void llama_kv_cache_seq_add(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta);
+    bool get_can_shift() const override;
 
-void llama_kv_cache_seq_div(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                          int   d);
+    // find an empty slot of size "n_tokens" in the cache
+    // updates the cache head
+    // returns a structure holding information about the slot found
+    // Note: On success, it's important that cache.head points
+    // to the first cell of the slot.
+    llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
 
-llama_pos llama_kv_cache_seq_pos_max(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id);
+    // TODO: maybe not needed
+    uint32_t get_padding(const llama_cparams & cparams) const;
 
-void llama_kv_cache_defrag(struct llama_kv_cache & cache);
+    // find how many cells are currently in use
+    uint32_t cell_max() const;
 
-int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
 
-int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
+    // defrag
 
-bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
+    struct {
+        std::vector<uint32_t> ids;
+    } defrag_info;
 
-//
-// kv cache view
-//
+    // return true if cells have been moved
+    bool defrag_prepare(int32_t n_max_nodes);
+
+    // state save/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
 
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
+    // members
 
-void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
+    const llama_hparams & hparams;
+
+    callbacks cbs;
+
+    bool has_shift = false;
+    bool do_defrag = false;
+
+    // TODO: remove this and implement llama_kv_cache_recurrent instead
+    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+
+    bool v_trans   = true;  // the value tensor is transposed
+    bool can_shift = false;
+
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_impl also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
+    uint32_t head = 0;
+    uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<llama_kv_cell> cells;
+
+    std::vector<ggml_tensor *> k_l; // per layer
+    std::vector<ggml_tensor *> v_l;
+
+private:
+    ggml_type type_k = GGML_TYPE_F16;
+    ggml_type type_v = GGML_TYPE_F16;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
+
+// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
+//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
+//public:
+//    using llama_kv_cache_unified::llama_kv_cache_unified;
+//};
 
 //
 // kv cache restore
@@ -184,13 +201,15 @@ struct llama_kv_slot_restorer {
 
     bool do_restore = false;
 
-    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+    llama_kv_cache_unified & cache;
+
+    explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
         old_state.head = cache.head;
         old_state.n    = cache.n;
     }
 
     // saves a slot information for future restoration
-    void save(const struct llama_kv_cache_slot_info & slot) {
+    void save(const llama_kv_cache_slot_info & slot) {
         if (slot) {
             do_restore = true;
             if (slot.boundaries.first != slot.boundaries.second) {
@@ -201,19 +220,68 @@ struct llama_kv_slot_restorer {
 
     // must be explicitly called to restore the kv_cache state
     // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore(struct llama_kv_cache & cache) {
+    void restore() {
         if (do_restore) {
             cache.head = old_state.head;
             cache.n    = old_state.n;
 
             if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
-                llama_kv_cache_seq_rm(cache, -1, -1, -1);
+                cache.seq_rm(-1, -1, -1);
             } else {
                 for (auto & slot : slot_boundaries) {
-                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
+                    cache.seq_rm(-1, slot.first, slot.second);
                 }
             }
         }
     }
 };
 
+// TODO: maybe become part of the public llama_kv_cache in the future
+int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
+
+int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
+
+void llama_kv_cache_clear(llama_kv_cache * kv);
+
+bool llama_kv_cache_seq_rm(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1);
+
+void llama_kv_cache_seq_cp(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id_src,
+          llama_seq_id   seq_id_dst,
+             llama_pos   p0,
+             llama_pos   p1);
+
+void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
+
+void llama_kv_cache_seq_add(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1,
+             llama_pos   delta);
+
+void llama_kv_cache_seq_div(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1,
+                   int   d);
+
+llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
+
+void llama_kv_cache_defrag(llama_kv_cache * kv);
+
+bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
+
+//
+// kv cache view
+//
+
+llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
+
+void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
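
With this refactor the sequence operations become virtual methods on the cache object, and the free functions above are thin wrappers retained for internal callers. A hedged sketch of calling through the new interface, with a fallback policy chosen purely for illustration:

```cpp
#include "llama-kv-cache.h"

// Sketch of calling through the new virtual interface instead of the old
// free functions (llama_kv_cache_seq_rm(cache, ...) becomes kv->seq_rm(...)).
// seq_rm() can fail for caches that cannot erase a partial range (e.g. the
// recurrent case); dropping the whole sequence is one possible fallback,
// not a prescribed policy.
static void prune_range(llama_kv_cache * kv, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    if (!kv->seq_rm(seq_id, p0, p1)) {
        kv->seq_rm(seq_id, -1, -1); // full-range removal of the sequence
    }
}
```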

+ 1 - 0
src/llama-memory.cpp

@@ -0,0 +1 @@
+#include "llama-memory.h"

+ 21 - 0
src/llama-memory.h

@@ -0,0 +1,21 @@
+#pragma once
+
+#include "llama.h"
+
+// general concept of LLM memory
+// the KV cache is a type of LLM memory, but there can be other types
+class llama_memory_i {
+public:
+    virtual void clear() = 0;
+    virtual void defrag() = 0;
+
+    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_keep(llama_seq_id seq_id) = 0;
+    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
+
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
+
+    virtual bool get_can_edit() const = 0;
+};
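
Since the KV cache derives from this interface, code that only needs sequence bookkeeping can be written against llama_memory_i and stay agnostic of the concrete memory type. A trivial sketch:

```cpp
#include "llama-memory.h"

// Sketch: works against any memory implementation (the unified KV cache
// today, possibly other memory types later).
static void keep_only(llama_memory_i & mem, llama_seq_id seq_id) {
    mem.seq_keep(seq_id); // removes every other sequence
}
```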

+ 7529 - 219
src/llama-model.cpp

@@ -2,12 +2,16 @@
 
 #include "llama-impl.h"
 #include "llama-mmap.h"
+#include "llama-batch.h"
+#include "llama-cparams.h"
 #include "llama-model-loader.h"
+#include "llama-kv-cache.h"
 
 #include "ggml-cpp.h"
 
 #include <algorithm>
 #include <cassert>
+#include <cmath>
+#include <cfloat>
 #include <cstring>
 #include <cmath>
 #include <functional>
@@ -244,6 +249,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara
             return cur_buft;
         }
     }
+
     return nullptr;
 }
 
@@ -302,7 +308,7 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
 }
 
 // GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU
-static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) {
+static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode split_mode, const float * tensor_split) {
     buft_list_t buft_list;
 
     // add the device split buffer type if requested and available
@@ -369,7 +375,7 @@ struct llama_model::impl {
     std::vector<layer_dev> dev_layer;
 };
 
-llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
 }
 
 llama_model::~llama_model() {}
@@ -391,7 +397,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     // get metadata as string
     for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
-        enum gguf_type type = gguf_get_kv_type(ctx, i);
+        gguf_type type = gguf_get_kv_type(ctx, i);
         if (type == GGUF_TYPE_ARRAY) {
             continue;
         }
@@ -1444,7 +1450,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
             // skip unused tensors
             if (info.op == GGML_OP_NONE) {
-                LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str());
+                const size_t nbytes = ggml_nbytes(t_meta);
+                LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes);
+
+                ml.size_data -= nbytes;
                 ml.n_created++;
 
                 return nullptr;
@@ -3631,8 +3640,8 @@ size_t llama_model::size() const {
     return pimpl->n_bytes;
 }
 
-size_t llama_model::max_nodes() const {
-    return std::max<size_t>(8192, tensors_by_name.size()*5);
+size_t llama_model::n_tensors() const {
+    return tensors_by_name.size();
 }
 
 size_t llama_model::n_devices() const {
@@ -3745,7 +3754,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",     __func__, hparams.expert_weights_norm);
-        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func));
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
 
@@ -3821,9 +3830,9 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
             });
 }
 
-const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
+const ggml_tensor * llama_model::get_tensor(const char * name) const {
     auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
-            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
+            [name](const std::pair<std::string, ggml_tensor *> & it) {
                 return it.first == name;
             });
     if (it == tensors_by_name.end()) {
@@ -3833,255 +3842,7556 @@ const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
-//
-// interface implementation
-//
+struct llm_build_llama : public llm_graph_context {
+    llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
 
-struct llama_model_params llama_model_default_params() {
-    struct llama_model_params result = {
-        /*.devices                     =*/ nullptr,
-        /*.n_gpu_layers                =*/ 0,
-        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
-        /*.main_gpu                    =*/ 0,
-        /*.tensor_split                =*/ nullptr,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-        /*.vocab_only                  =*/ false,
-        /*.use_mmap                    =*/ true,
-        /*.use_mlock                   =*/ false,
-        /*.check_tensors               =*/ false,
-    };
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
 
-    return result;
-}
+        inpL = build_inp_embd(model.tok_embd);
 
-const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) {
-    return &model->vocab;
-}
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
 
-void llama_free_model(struct llama_model * model) {
-    llama_model_free(model);
-}
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
 
-void llama_model_free(struct llama_model * model) {
-    delete model;
-}
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
 
-int32_t llama_model_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
 
-int32_t llama_model_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
 
-int32_t llama_model_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
 
-int32_t llama_model_n_head(const struct llama_model * model) {
-    return model->hparams.n_head();
-}
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
 
-int32_t llama_model_n_head_kv(const struct llama_model * model) {
-    return model->hparams.n_head_kv();
-}
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+            }
 
-// deprecated
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return llama_model_n_ctx_train(model);
-}
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
 
-// deprecated
-int32_t llama_n_embd(const struct llama_model * model) {
-    return llama_model_n_embd(model);
-}
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
 
-// deprecated
-int32_t llama_n_layer(const struct llama_model * model) {
-    return llama_model_n_layer(model);
-}
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
 
-// deprecated
-int32_t llama_n_head(const struct llama_model * model) {
-    return llama_model_n_head(model);
-}
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
 
-enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
-    switch (model->arch) {
-        // these models do not use RoPE
-        case LLM_ARCH_GPT2:
-        case LLM_ARCH_GPTJ:
-        case LLM_ARCH_MPT:
-        case LLM_ARCH_REFACT:
-        case LLM_ARCH_BLOOM:
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_JINA_BERT_V2:
-        case LLM_ARCH_T5:
-        case LLM_ARCH_T5ENCODER:
-        case LLM_ARCH_JAIS:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_WAVTOKENIZER_DEC:
-            return LLAMA_ROPE_TYPE_NONE;
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
 
-        // use what we call a normal RoPE, operating on pairs of consecutive head values
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_DECI:
-        case LLM_ARCH_BAICHUAN:
-        case LLM_ARCH_STARCODER:
-        case LLM_ARCH_PLAMO:
-        case LLM_ARCH_ORION:
-        case LLM_ARCH_INTERNLM2:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_XVERSE:
-        case LLM_ARCH_COMMAND_R:
-        case LLM_ARCH_COHERE2:
-        case LLM_ARCH_OLMO:
-        case LLM_ARCH_ARCTIC:
-        case LLM_ARCH_DEEPSEEK:
-        case LLM_ARCH_DEEPSEEK2:
-        case LLM_ARCH_CHATGLM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-        case LLM_ARCH_CHAMELEON:
-            return LLAMA_ROPE_TYPE_NORM;
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
 
-        // the pairs of head values are offset by n_rot/2
-        case LLM_ARCH_FALCON:
-        case LLM_ARCH_GROK:
-        case LLM_ARCH_DBRX:
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_NOMIC_BERT:
-        case LLM_ARCH_STABLELM:
-        case LLM_ARCH_BITNET:
-        case LLM_ARCH_QWEN:
-        case LLM_ARCH_QWEN2:
-        case LLM_ARCH_QWEN2MOE:
-        case LLM_ARCH_OLMO2:
-        case LLM_ARCH_OLMOE:
-        case LLM_ARCH_PHI2:
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_PHIMOE:
-        case LLM_ARCH_GEMMA:
-        case LLM_ARCH_GEMMA2:
-        case LLM_ARCH_GEMMA3:
-        case LLM_ARCH_STARCODER2:
-        case LLM_ARCH_OPENELM:
-        case LLM_ARCH_GPTNEOX:
-        case LLM_ARCH_CODESHELL:
-        case LLM_ARCH_NEMOTRON:
-        case LLM_ARCH_EXAONE:
-        case LLM_ARCH_MINICPM3:
-            return LLAMA_ROPE_TYPE_NEOX;
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
 
-        case LLM_ARCH_QWEN2VL:
-            return LLAMA_ROPE_TYPE_MROPE;
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
 
-        // all model arches should be listed explicitly here
-        case LLM_ARCH_UNKNOWN:
-            GGML_ABORT("unknown architecture");
-    }
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
 
-    return LLAMA_ROPE_TYPE_NONE;
-}
+            // input for next layer
+            inpL = cur;
+        }
 
-float llama_model_rope_freq_scale_train(const struct llama_model * model) {
-    return model->hparams.rope_freq_scale_train;
-}
+        cur = inpL;
 
-int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
         }
-        return -1;
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
     }
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
+};
 
-int32_t llama_model_meta_count(const struct llama_model * model) {
-    return (int)model->gguf_kv.size();
-}
+struct llm_build_deci : public llm_graph_context {
+    llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
 
-int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
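+        // unified KV-cache attention input; the flags request a causal mask without SWA (assuming the (causal, swa) flag order used by this refactor)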
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
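+        // f_attention_scale == 0.0f means "unset": fall back to the standard 1/sqrt(head_dim) scaling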
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head    = hparams.n_head(il);
+
+            if (n_head == 0) {
+                // attention-free layer of Llama-3_1-Nemotron-51B
+                cur = inpL;
+            } else {
+                // norm
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            if (n_head > 0 && n_head_kv == 0) {
+                // "linear attention" of Llama-3_1-Nemotron-51B
+                cur = build_lora_mm(model.layers[il].wo, cur);
+                cb(cur, "wo", il);
+            } else if (n_head > 0) {
+                // self-attention
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
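+                // the factors are fetched through the unified KV cache callbacks, keeping the graph builder free of rope-factor state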
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
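+            // (only rows with requested outputs survive past this point, shrinking the tail of the graph)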
+
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
+            // modified to support attention-free layer of Llama-3_1-Nemotron-51B
+            ggml_tensor * ffn_inp = cur;
+            if (n_head > 0) {
+                ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+            }
+
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->first.c_str());
-}
 
-int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
         }
-        return -1;
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
     }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
+};
 
-int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s", model->desc().c_str());
-}
+struct llm_build_baichuan : public llm_graph_context {
+    llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
 
-uint64_t llama_model_size(const struct llama_model * model) {
-    return model->size();
-}
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
-    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
-        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        return nullptr;
-    }
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
 
-    return it->second.c_str();
-}
+        inpL = build_inp_embd(model.tok_embd);
 
-uint64_t llama_model_n_params(const struct llama_model * model) {
-    return model->n_elements();
-}
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
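+        // only the 7B variant applies RoPE; Baichuan-13B uses ALiBi and therefore needs no position input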
 
-bool llama_model_has_encoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5:        return true;
-        case LLM_ARCH_T5ENCODER: return true;
-        default:                 return false;
-    }
-}
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
 
-bool llama_model_has_decoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5ENCODER: return false;
-        default:                 return true;
-    }
-}
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
 
-llama_token llama_model_decoder_start_token(const struct llama_model * model) {
-    return model->hparams.dec_start_token_id;
-}
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
 
-bool llama_model_is_recurrent(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:  return true;
-        case LLM_ARCH_RWKV6:  return true;
-        case LLM_ARCH_RWKV6QWEN2: return true;
-        default:              return false;
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                switch (model.type) {
+                    case LLM_TYPE_7B:
+                        Qcur = ggml_rope_ext(
+                                ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                ext_factor, attn_factor, beta_fast, beta_slow
+                                );
+                        Kcur = ggml_rope_ext(
+                                ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                ext_factor, attn_factor, beta_fast, beta_slow
+                                );
+                        break;
+                    case LLM_TYPE_13B:
+                        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
+                        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
     }
+};
+
+struct llm_build_xverse : public llm_graph_context {
+    llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_falcon : public llm_graph_context {
+    llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * attn_norm;
+
+            attn_norm = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(attn_norm, "attn_norm", il);
+
+            // self-attention
+            {
+                if (model.layers[il].attn_norm_2) {
+                    // Falcon-40B
+                    cur = build_norm(inpL,
+                            model.layers[il].attn_norm_2,
+                            model.layers[il].attn_norm_2_b,
+                            LLM_NORM, il);
+                    cb(cur, "attn_norm_2", il);
+                } else {
+                    cur = attn_norm;
+                }
+
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
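+                // split the fused QKV projection by row offset (in floats):
+                //   Q: [0, n_embd), K: [n_embd, n_embd + n_embd_gqa), V: [n_embd + n_embd_gqa, n_embd + 2*n_embd_gqa)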
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                // NeoX-style RoPE (GGML_ROPE_TYPE_NEOX == 2): rotated pairs are offset by n_rot/2
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur       = ggml_get_rows(ctx0,       cur, inp_out_ids);
+                inpL      = ggml_get_rows(ctx0,      inpL, inp_out_ids);
+                attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = cur;
+
+            // feed forward
+            {
+                cur = build_ffn(attn_norm, // !! use the attn norm, not the result
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
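+            // parallel attention/FFN block: both branches read attn_norm and are summed with the layer input below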
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_grok : public llm_graph_context {
+    llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // multiply by embedding_multiplier_scale of 78.38367176906169
+        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
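+        // note: 78.38367176906169 ~ sqrt(6144), presumably sqrt(n_embd) for Grok-1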
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // Grok
+            // if attn_out_norm is present then apply it before adding the input
+            if (model.layers[il].attn_out_norm) {
+                cur = build_norm(cur,
+                        model.layers[il].attn_out_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_out_norm", il);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_GELU, true,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            // Grok
+            // if layer_out_norm is present then apply it before adding the input
+            // Idea: maybe ffn_out_norm is a better name
+            if (model.layers[il].layer_out_norm) {
+                cur = build_norm(cur,
+                        model.layers[il].layer_out_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "layer_out_norm", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // Grok
+        // multiply logits by output_multiplier_scale of 0.5773502691896257
+
+        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
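+        // (0.5773502691896257 = 1/sqrt(3))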
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_dbrx : public llm_graph_context {
+    llm_build_dbrx(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                cb(cur, "wqkv_clamped", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].attn_out_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_out_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_starcoder : public llm_graph_context {
+    llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
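+        // GPT-2-style learned absolute position embeddings, added to the token embeddings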
+        ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+        cb(pos, "pos_embd", -1);
+
+        inpL = ggml_add(ctx0, inpL, pos);
+        cb(inpL, "inpL", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_refact : public llm_graph_context {
+    llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
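+        // note: no position input is built; Refact encodes position with an ALiBi-style bias rather than RoPE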
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                cb(Qcur, "Qcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_bert : public llm_graph_context {
+    llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inp_pos = nullptr;
+
+        if (model.arch != LLM_ARCH_JINA_BERT_V2) {
+            inp_pos = build_inp_pos();
+        }
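+        // JINA_BERT_V2 uses ALiBi, so it is the only variant here that needs no position input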
+
+        // construct input embeddings (token, type, position)
+        inpL = build_inp_embd(model.tok_embd);
+
+        // token types are hardcoded to zero ("Sentence A")
+        ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
+        inpL = ggml_add(ctx0, inpL, type_row0);
+        if (model.arch == LLM_ARCH_BERT) {
+            inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+        }
+        cb(inpL, "inp_embd", -1);
+
+        // embed layer norm
+        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+        cb(inpL, "inp_norm", -1);
+
+        auto * inp_attn = build_attn_inp_no_cache();
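+        // encoder-only model: bidirectional (non-causal) attention with no KV cache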
+
+        // iterate layers
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * cur = inpL;
+
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            // self-attention
+            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
+                Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            model.layers[il].attn_q_norm_b,
+                            LLM_NORM, il);
+                }
+
+                Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
+
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            model.layers[il].attn_k_norm_b,
+                            LLM_NORM, il);
+                }
+
+                Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            } else {
+                // fused QKV projection; RoPE is applied to Q and K below
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            }
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
+
+            // attention layer norm
+            cur = build_norm(cur, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, il);
+
+            if (model.layers[il].attn_norm_2 != nullptr) {
+                cur = ggml_add(ctx0, cur, inpL); // re-add the layer input
+                cur = build_norm(cur, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, il);
+            }
+
+            ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            if (model.arch == LLM_ARCH_BERT) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+            } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL,                        NULL,
+                        model.layers[il].ffn_gate, NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+            } else {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+            }
+            cb(cur, "ffn_out", il);
+
+            // residual connection: the attention output bypasses the FFN
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            // output layer norm
+            cur = build_norm(cur, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_bloom : public llm_graph_context {
+    llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        inpL = build_norm(inpL,
+                model.tok_norm,
+                model.tok_norm_b,
+                LLM_NORM, -1);
+        cb(inpL, "inp_norm", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // Add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_mpt : public llm_graph_context {
+    llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * pos;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
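+        // most MPT checkpoints use ALiBi and ship no learned position embeddings, so pos_embd is typically absent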
+        if (model.pos_embd) {
+            // inp_pos - contains the positions
+            ggml_tensor * inp_pos = build_inp_pos();
+            pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+            cb(pos, "pos_embd", -1);
+
+            inpL = ggml_add(ctx0, inpL, pos);
+            cb(inpL, "inpL", -1);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * attn_norm;
+
+            attn_norm = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(attn_norm, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = attn_norm;
+
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                if (model.layers[il].bqkv) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+                }
+
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(cur, "wqkv_clamped", il);
+                }
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // Q/K Layernorm
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            model.layers[il].attn_q_norm_b,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            model.layers[il].attn_k_norm_b,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                    cur = build_attn(inp_attn, gf,
+                            model.layers[il].wo, model.layers[il].bo,
+                            Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                } else {
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                    cur = build_attn(inp_attn, gf,
+                            model.layers[il].wo, model.layers[il].bo,
+                            Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                }
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // Add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed forward
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        model.layers[il].ffn_act,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_stablelm : public llm_graph_context {
+    llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            ggml_tensor * inpSA = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                cb(Qcur, "Qcur", il);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                cb(Kcur, "Kcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+                }
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpL  = ggml_get_rows(ctx0,  inpL, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                if (model.layers[il].ffn_norm) {
+                    cur = build_norm(ffn_inp,
+                            model.layers[il].ffn_norm,
+                            model.layers[il].ffn_norm_b,
+                            LLM_NORM, il);
+                    cb(cur, "ffn_norm", il);
+                } else {
+                    // parallel residual
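+                    // no ffn_norm tensor: reuse the attention-normed input (inpSA) for the FFN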
+                    cur = inpSA;
+                }
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen : public llm_graph_context {
+    llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                // NeoX-style RoPE (ggml rope mode 2)
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen2 : public llm_graph_context {
+    llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen2vl : public llm_graph_context {
+    llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        int sections[4];
+        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
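+        // M-RoPE: ggml_rope_multi rotates the head dims in the 4 groups given by rope_sections,
+        // each driven by its own position stream (temporal/height/width for Qwen2-VL)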
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_multi(
+                        ctx0,
+                        ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_multi(
+                        ctx0,
+                        ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen2moe : public llm_graph_context {
+    llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // FFN shared expert
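+            // an always-active shared expert, gated per token by sigmoid(ffn_gate_inp_shexp), is added to the routed output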
+            {
+                ggml_tensor * cur_gate_inp = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+                cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
+
+                // sigmoid: silu(x)/x == x*sigmoid(x)/x == sigmoid(x)
+                ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
+                cb(cur_gate, "ffn_shexp_gate", il);
+
+                ggml_tensor * cur_ffn = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur_ffn, "ffn_shexp", il);
+
+                ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate);
+                cb(ffn_shexp_out, "ffn_shexp_out", il);
+
+                moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out);
+                cb(moe_out, "ffn_out", il);
+
+                cur = moe_out;
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_phi2 : public llm_graph_context {
+    llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * attn_norm_output;
+        ggml_tensor * ffn_output;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            attn_norm_output = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(attn_norm_output, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
+                if (model.layers[il].wqkv) {
+                    cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output);
+                    cb(cur, "wqkv", il);
+
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                } else {
+                    Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                // with phi2, we scale the Q to avoid precision issues
+                // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
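+                // (the attention below is therefore invoked with kq_scale = 1.0f)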
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur              = ggml_get_rows(ctx0,              cur, inp_out_ids);
+                inpL             = ggml_get_rows(ctx0,             inpL, inp_out_ids);
+                attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
+            }
+
+            // FF
+            {
+                ffn_output = build_ffn(attn_norm_output,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(ffn_output, "ffn_out", il);
+            }
+
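+            // phi2 runs the attention and FFN branches in parallel off the same attn_norm_output,
+            // then sums both with the residual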
+            cur = ggml_add(ctx0, cur, ffn_output);
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output_no_bias", -1);
+
+        cur = ggml_add(ctx0, cur, model.output_b);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_phi3 : public llm_graph_context {
+    llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        for (int il = 0; il < n_layer; ++il) {
+            auto * residual = inpL;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
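+                // (expected to select the model's long- or short-context factors based on n_ctx_per_seq)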
+
+                ggml_tensor * attn_norm_output = build_norm(inpL,
+                        model.layers[il].attn_norm,
+                        model.layers[il].attn_norm_b,
+                        LLM_NORM_RMS, il);
+                cb(attn_norm_output, "attn_norm", il);
+
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
+                if (model.layers[il].wqkv) {
+                    cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output);
+                    cb(cur, "wqkv", il);
+
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
+                } else {
+                    Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur      = ggml_get_rows(ctx0, cur,      inp_out_ids);
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+            }
+
+            cur = ggml_add(ctx0, cur, residual);
+            residual = cur;
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
+
+            cur = ggml_add(ctx0, residual, cur);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_plamo : public llm_graph_context {
+    llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            ggml_tensor * attention_norm = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+            ggml_tensor * sa_out = cur;
+
+            cur = attention_norm;
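+            // the FFN consumes the same normalized input as the attention; sa_out and the residual are added back below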
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur    = ggml_get_rows(ctx0,    cur, inp_out_ids);
+                sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
+                inpL   = ggml_get_rows(ctx0,   inpL, inp_out_ids);
+            }
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, sa_out);
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gpt2 : public llm_graph_context {
+    llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * pos;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
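+        // GPT-2 uses learned absolute position embeddings, added to the token embeddings (no RoPE)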
+        pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+        cb(pos, "pos_embd", -1);
+
+        inpL = ggml_add(ctx0, inpL, pos);
+        cb(inpL, "inpL", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_codeshell : public llm_graph_context {
+    llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(tmpq, "tmpq", il);
+                cb(tmpk, "tmpk", il);
+                cb(Vcur, "Vcur", il);
+
+                ggml_tensor * Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_orion : public llm_graph_context {
+    llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                // if (model.layers[il].bq) {
+                //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                //     cb(Qcur, "Qcur", il);
+                // }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                // if (model.layers[il].bk) {
+                //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                //     cb(Kcur, "Kcur", il);
+                // }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                // if (model.layers[il].bv) {
+                //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                //     cb(Vcur, "Vcur", il);
+                // }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_internlm2 : public llm_graph_context {
+    llm_build_internlm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_minicpm3 : public llm_graph_context {
+    llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        // TODO: read these parameters from the model instead of hard-coding them, in case they vary between checkpoints
+        const int64_t n_embd_base = 256;
+        const float scale_embd  = 12.0f;
+        const float scale_depth = 1.4f;
+        const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // scale the input embeddings
+        inpL = ggml_scale(ctx0, inpL, scale_embd);
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * q = NULL;
+                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                cb(q, "q", il);
+
+                q = build_norm(q,
+                        model.layers[il].attn_q_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(q, "q", il);
+
+                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                cb(q, "q", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * kv_pe_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compressed, "kv_pe_compressed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compressed, kv_lora_rank, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compressed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        kv_pe_compressed->nb[1],
+                        ggml_row_size(kv_pe_compressed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                // TODO: the CUDA backend previously did not support non-contiguous (RMS) norm - investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
+                kv_compressed = build_norm(kv_compressed,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                        0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend previously did not support non-contiguous RoPE - investigate removing this
+                q_pe = ggml_rope_ext(
+                        ctx0, q_pe, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend previously did not support non-contiguous RoPE - investigate removing this
+                k_pe = ggml_rope_ext(
+                        ctx0, k_pe, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(k_pe, "k_pe", il);
+
+                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        q_states, k_states, v_states, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // scale_res - scale the hidden states for residual connection
+            const float scale_res = scale_depth/sqrtf(float(n_layer));
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled", il);
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // scale the hidden states for residual connection
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled_ffn", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head scaling
+        const float scale_lmhead = float(n_embd_base)/float(n_embd);
+        cur = ggml_scale(ctx0, cur, scale_lmhead);
+        cb(cur, "lmhead_scaling", -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gemma : public llm_graph_context {
+    llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                cb(Qcur, "Qcur_scaled", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gemma2 : public llm_graph_context {
+    llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                switch (model.type) {
+                    case LLM_TYPE_2B:
+                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    default: GGML_ABORT("fatal error");
+                };
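+                // note: per the referenced commit, the 27B variant scales queries by 1/sqrt(n_embd/n_head)
+                // (query_pre_attn_scalar in the reference implementation) instead of 1/sqrt(n_embd_head_k)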
+                cb(Qcur, "Qcur_scaled", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
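+        // i.e. logits = softcap * tanh(logits / softcap), with softcap = hparams.f_final_logit_softcapping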
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gemma3 : public llm_graph_context {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not scale raw (non-token) embeddings input, i.e. encoded image embeddings
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // TODO: is causal == true correct? might need some changes
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        // "5-to-1 interleaved attention"
+        // 5 layers of local attention followed by 1 layer of global attention
+        static const int sliding_window_pattern = 6;
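+        // i.e. il % 6 in {0..4} -> local (sliding) attention; il % 6 == 5 -> global attention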
+
+        for (int il = 0; il < n_layer; ++il) {
+            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+
+            const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
+            const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens);
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// TODO: move up next to build_starcoder
+struct llm_build_starcoder2 : public llm_graph_context {
+    llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    NULL,                      NULL,                        NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_GELU, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_mamba : public llm_graph_context {
+    const llama_model & model;
+
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        // final rmsnorm
+        cur = build_norm(inpL,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    // TODO: split
+    ggml_tensor * build_mamba_layer(
+             ggml_cgraph * gf,
+             ggml_tensor * cur,
+             ggml_tensor * state_copy,
+             ggml_tensor * state_mask,
+      const llama_ubatch & ubatch,
+                     int   il) const {
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        const auto kv_head = kv_self->head;
+
+        const int64_t d_conv  = hparams.ssm_d_conv;
+        const int64_t d_inner = hparams.ssm_d_inner;
+        const int64_t d_state = hparams.ssm_d_state;
+        const int64_t dt_rank = hparams.ssm_dt_rank;
+        const int64_t n_seqs  = ubatch.n_seqs;
+        // Some Mamba variants (e.g. FalconMamba) apply RMS norm on the B, C and dt layers
+        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+        // Use the same RMS norm as the final layer norm
+        const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        ggml_tensor * conv_states_all = kv_self->k_l[il];
+        ggml_tensor * ssm_states_all  = kv_self->v_l[il];
+
+        // (ab)using the KV cache to store the states
+        ggml_tensor * conv = build_copy_mask_state(
+                gf, conv_states_all, state_copy, state_mask,
+                hparams.n_embd_k_s(), n_seqs);
+        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
+        ggml_tensor * ssm = build_copy_mask_state(
+                gf, ssm_states_all, state_copy, state_mask,
+                hparams.n_embd_v_s(), n_seqs);
+        ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
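+        // per-sequence recurrent state restored from the cache:
+        //   conv: {d_conv - 1, d_inner, n_seqs} - last (d_conv - 1) conv inputs per channel
+        //   ssm : {d_state, d_inner, n_seqs}    - SSM hidden state per channel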
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+        // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+        ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur);
+        // split the above in two
+        // => {d_inner, n_seq_tokens, n_seqs}
+        ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+        ggml_tensor * z = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
+
+        // conv
+        {
+            // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+
+            // copy last (d_conv - 1) columns back into the state cache
+            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0, last_conv,
+                    ggml_view_1d(ctx0, conv_states_all,
+                        (d_conv - 1)*(d_inner)*(n_seqs),
+                        kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+
+            // 1D convolution
+            // The equivalent is to make a self-overlapping view of conv_x
+            // over d_conv columns at each stride in the 3rd dimension,
+            // then element-wise multiply that with the conv1d weight,
+            // then sum the elements of each row,
+            // (the last two steps are a dot product over rows (also doable with mul_mat))
+            // then permute away the ne[0] dimension,
+            // and then you're left with the resulting x tensor.
+            // For simultaneous sequences, all sequences need to have the same length.
+            x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+
+            // bias
+            x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+            x = ggml_silu(ctx0, x);
+        }
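+        // rough scalar form of the conv block above (illustrative), per channel c and time step t:
+        //   x[c,t] = silu(sum_{k=0..d_conv-1} conv_x[c, t+k] * conv1d_w[c,k] + conv1d_b[c])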
+
+        // ssm
+        {
+            // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+            ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
+            // split
+            ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+            ggml_tensor * B  = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
+            ggml_tensor * C  = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
+
+            // Some Mamba variants (e.g. FalconMamba) apply RMS norm on the B, C and dt layers
+            if (ssm_dt_b_c_rms) {
+                dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
+                B = ggml_rms_norm(ctx0, B, norm_rms_eps);
+                C = ggml_rms_norm(ctx0, C, norm_rms_eps);
+            }
+
+            // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+            dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+            dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+
+            // Custom operator to optimize the parallel associative scan
+            // as described in the Annex D of the Mamba paper.
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C);
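+            // roughly, per channel (see the Mamba paper):
+            //   h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
+            //   y_t = dot(C_t, h_t)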
+
+            // store last states
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+            ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+
+            // TODO: skip computing output earlier for unused tokens
+
+            // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+            y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
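+            // x*ssm_d is the skip connection; silu(z) acts as the output gate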
+
+            // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+            cur = build_lora_mm(model.layers[il].ssm_out, y);
+        }
+
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+        //cb(cur, "mamba_out", il);
+
+        return cur;
+    }
+};
+
+struct llm_build_command_r : public llm_graph_context {
+    llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+                            ggml_element_size(Qcur) * n_embd_head,
+                            ggml_element_size(Qcur) * n_embd_head * n_head,
+                            0);
+                    cb(Qcur, "Qcur", il);
+                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+                            ggml_element_size(Kcur) * n_embd_head,
+                            ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+                            0);
+                    cb(Kcur, "Kcur", il);
+
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0,     cur, inp_out_ids);
+                inpL    = ggml_get_rows(ctx0,    inpL, inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = build_ffn(ffn_inp,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_cohere2 : public llm_graph_context {
+    llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        // sliding window switch pattern
+        const int32_t sliding_window_pattern = 4;
+
+        for (int il = 0; il < n_layer; ++il) {
+            // three layers of sliding-window attention (window size 4096) with RoPE, followed by
+            // a fourth layer of global attention without positional embeddings
+            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
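+            // i.e. il % 4 in {0,1,2} -> sliding (with RoPE); il % 4 == 3 -> global (no RoPE)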
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (is_sliding) {
+                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                            beta_fast, beta_slow);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                            rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                            attn_factor, beta_fast, beta_slow);
+                    cb(Kcur, "Kcur", il);
+                } else {
+                    // For non-sliding layers, just reshape without applying RoPE
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL    = ggml_get_rows(ctx0, inpL, inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = build_ffn(ffn_inp,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// ref: https://allenai.org/olmo
+// based on the original build_llama() function, changes:
+//   * non-parametric layer norm
+//   * clamp qkv
+//   * removed bias
+//   * removed MoE
+struct llm_build_olmo : public llm_graph_context {
+    llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    NULL, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, nullptr,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    NULL, NULL,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                NULL, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_olmo2 : public llm_graph_context {
+    llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = inpL;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_ffn(ffn_inp,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// based on the build_qwen2moe() function, changes:
+//   * removed shared experts
+//   * removed bias
+//   * added q, k norm
+struct llm_build_olmoe : public llm_graph_context {
+    llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, false,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_openelm : public llm_graph_context {
+    llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head    = hparams.n_head(il);
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_qkv = 2*n_head_kv + n_head;
+
+            cur = inpL;
+            ggml_tensor * residual = cur;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
+
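+                // the fused QKV projection packs the heads as [ n_head Q | n_head_kv K | n_head_kv V ]
+                // along the head dimension - the strided views below slice it apart, with the
+                // offsets expressed in whole heads (multiples of cur->nb[1])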
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+                cb(Vcur, "Vcur", il);
+
+                Qcur = build_norm(Qcur,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = build_norm(Kcur,
+                        model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
+                cb(Qcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gptneox : public llm_graph_context {
+    llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
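+                // the fused QKV output is laid out as [ Q : n_embd | K : n_embd_gqa | V : n_embd_gqa ]
+                // along dim 0 - note that the byte offsets below assume f32 activations,
+                // i.e. sizeof(float) == cur->nb[0]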
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // ffn
+            if (hparams.use_par_res) {
+                // attention and ffn are computed in parallel
+                // x = x + attn(ln1(x)) + ffn(ln2(x))
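+                // (this path corresponds to the use_parallel_residual flag of GPT-NeoX checkpoints)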
+
+                ggml_tensor * attn_out = cur;
+
+                cur = build_norm(inpL,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, inpL);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, attn_out);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            } else {
+                // attention and ffn are computed sequentially
+                // x = x + attn(ln1(x))
+                // x = x + ffn(ln2(x))
+
+                ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+                cb(ffn_inp, "ffn_inp", il);
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_arctic : public llm_graph_context {
+    llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp);
+            cb(ffn_out, "ffn_out", il);
+
+            // MoE
+            cur = build_norm(inpSA,
+                    model.layers[il].ffn_norm_exps, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm_exps", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_out);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_deepseek : public llm_graph_context {
+    llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
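+        // f_attention_scale == 0.0f means the scale is not set - fall back to the standard 1/sqrt(n_embd_head)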
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            nullptr,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, false,
+                            false, hparams.expert_weights_scale,
+                            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_deepseek2 : public llm_graph_context {
+    llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        bool is_lite = (hparams.n_layer == 27);
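+        // note: the lite variant (DeepSeek-V2-Lite) does not use the low-rank Q projection
+        //       (wq_a/wq_b) - see the !is_lite branch below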
+
+        // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
+        // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+        const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+        const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
+        const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
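+        // in short: when ext_factor != 0, ggml_rope_ext() internally scales the rotated dims by
+        // attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)), so attn_factor_scaled cancels
+        // that factor out and the full mscale^2 correction is applied via kq_scale instead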
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * q = NULL;
+                if (!is_lite) {
+                    // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                    cb(q, "q", il);
+
+                    q = build_norm(q,
+                            model.layers[il].attn_q_a_norm, NULL,
+                            LLM_NORM_RMS, il);
+                    cb(q, "q", il);
+
+                    // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                    cb(q, "q", il);
+                } else {
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                    cb(q, "q", il);
+                }
+
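+                // each Q head is laid out as [ n_embd_head_qk_nope | n_embd_head_qk_rope ] along
+                // dim 0, so the two views below share the same strides and differ only in offset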
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * kv_pe_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compressed, "kv_pe_compressed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compressed, kv_lora_rank, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compressed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        kv_pe_compressed->nb[1],
+                        ggml_row_size(kv_pe_compressed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+                cb(k_pe, "k_pe", il);
+
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
+                kv_compressed = build_norm(kv_compressed,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                        0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                q_pe = ggml_rope_ext(
+                        ctx0, q_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                k_pe = ggml_rope_ext(
+                        ctx0, k_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        );
+                cb(k_pe, "k_pe", il);
+
+                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
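+                // k_pe is computed for a single head and broadcast to all n_head heads
+                // via ggml_repeat() before being concatenated with the per-head k_nope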
+                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        q_states, k_states, v_states, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
+                            true, hparams.expert_weights_scale,
+                            (llama_expert_gating_func_type) hparams.expert_gating_func,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_bitnet : public llm_graph_context {
+    llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
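+                // BitNet keeps per-tensor scales separate from the quantized weights, so each
+                // projection is followed by an optional ggml_mul() with the corresponding *_scale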
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                if (model.layers[il].wq_scale) {
+                    Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                }
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                // B1.K
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                if (model.layers[il].wk_scale) {
+                    Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                }
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                // B1.V
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                if (model.layers[il].wv_scale) {
+                    Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                }
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        NULL, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+                cur = build_norm(cur,
+                        model.layers[il].attn_sub_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_sub_norm", il);
+
+                cur = build_lora_mm(model.layers[il].wo, cur);
+                if (model.layers[il].wo_scale) {
+                    cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                }
+                if (model.layers[il].bo) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bo);
+                }
+                cb(cur, "attn_o_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
+                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
+                    NULL,                      NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_sub_out", il);
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_sub_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_sub_norm", il);
+
+            cur = build_lora_mm(model.layers[il].ffn_down, cur);
+            if (model.layers[il].ffn_down_scale) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            }
+            cb(cur, "ffn_down", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        // FIXME: do not use model.tok_embd directly, duplicate as model.output
+        cur = build_lora_mm(model.tok_embd, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_t5_enc : public llm_graph_context {
+    llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc();
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm_enc, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
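+                // the relative position bias is stored only in the first layer - all other
+                // layers fall back to reusing it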
+                ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
+                ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo_enc, nullptr,
+                        Qcur, Kcur, Vcur, kq_b, 1.0f, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm_enc, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up_enc,   NULL, NULL,
+                        model.layers[il].ffn_gate_enc, NULL, NULL,
+                        model.layers[il].ffn_down_enc, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                        il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        cb(cur, "result_embd", -1);
+
+        cur = build_norm(cur,
+                model.output_norm_enc, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_t5_dec : public llm_graph_context {
+    llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * embd_enc       = build_inp_cross_embd();
+        ggml_tensor * pos_bucket_dec = build_inp_pos_bucket_dec();
+
+        const int64_t n_outputs_enc = embd_enc->ne[1];
+
+        auto * inp_attn_self  = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn_cross = build_attn_inp_cross();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+                ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b);
+
+                cur = build_attn(inp_attn_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, kq_b, 1.0f, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, inpSA);
+            cb(cur, "cross_inp", il);
+
+            ggml_tensor * inpCA = cur;
+
+            // norm
+            cur = build_norm(cur,
+                    model.layers[il].attn_norm_cross, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm_cross", il);
+
+            // cross-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);
+
+                cur = build_attn(inp_attn_cross, gf,
+                        model.layers[il].wo_cross, nullptr,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                        il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        cb(cur, "result_embd", -1);
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_jais : public llm_graph_context {
+    llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_chatglm : public llm_graph_context {
+    llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
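+        // attention input backed by the unified KV cache (the two flags select causal masking and sliding-window use)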
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
+                if (model.layers[il].wqkv == nullptr) {
+                    Qcur = build_lora_mm(model.layers[il].wq, cur);
+                    if (model.layers[il].bq) {
+                        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    }
+                    Kcur = build_lora_mm(model.layers[il].wk, cur);
+                    if (model.layers[il].bk) {
+                        Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    }
+                    Vcur = build_lora_mm(model.layers[il].wv, cur);
+                    if (model.layers[il].bv) {
+                        Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    }
+                } else {
+                    cur = build_lora_mm(model.layers[il].wqkv, cur);
+                    cb(cur, "wqkv", il);
+                    if (model.layers[il].bqkv) {
+                        cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                        cb(cur, "bqkv", il);
+                    }
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // Add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_nemotron : public llm_graph_context {
+    llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        //GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm,
+                    model.layers[il].ffn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    NULL,                      NULL,                        NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_exaone : public llm_graph_context {
+    llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors; may be nullptr for models that do not use them
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_rwkv6_base : public llm_graph_context {
+    const llama_model & model;
+
+    llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
+    }
+
+    ggml_tensor * build_rwkv6_channel_mix(
+            const llama_layer * layer,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            llm_arch arch) const {
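+        // channel mix: out = sigmoid(R(xr)) * V(relu(K(xk))^2),
+        // where xk/xr lerp between the current and previous token (sx = x_prev - cur)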
+        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+        switch (arch) {
+            case LLM_ARCH_RWKV6:
+                {
+                    ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
+                    ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
+
+                    ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
+                    ggml_tensor * k = ggml_sqr(
+                            ctx0,
+                            ggml_relu(
+                                ctx0,
+                                build_lora_mm(layer->channel_mix_key, xk)
+                                )
+                            );
+                    cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k));
+                } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
+
+        return cur;
+    }
+
+    ggml_tensor * build_rwkv6_time_mix(
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int   il) const {
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+        const auto n_embd = hparams.n_embd;
+        const auto head_size = hparams.wkv_head_size;
+        const auto n_head = n_embd / head_size;
+        const auto n_head_kv = hparams.n_head_kv(il);
+
+        const auto kv_head = kv_self->head;
+
+        const auto & layer = model.layers[il];
+
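+        // QRWKV variants (e.g. rwkv6qwen2) ship no time_mix_first tensor and take the gated-linear-attention path below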
+        bool is_qrwkv = layer.time_mix_first == nullptr;
+
+        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+        ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);
+
+        xxx = ggml_reshape_4d(
+                ctx0,
+                ggml_tanh(
+                    ctx0,
+                    ggml_mul_mat(ctx0, layer.time_mix_w1, xxx)
+                    ),
+                layer.time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+                );
+
+        xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
+
+        xxx = ggml_mul_mat(
+                ctx0,
+                ggml_reshape_4d(
+                    ctx0,
+                    layer.time_mix_w2,
+                    layer.time_mix_w2->ne[0], layer.time_mix_w2->ne[1], 1, 5
+                    ),
+                xxx
+                );
+
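+        // xxx now stacks the 5 data-dependent lerp offsets (w, k, v, r, g), one n_embd x n_tokens slice each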
+        ggml_tensor *xw, *xk, *xv, *xr, *xg;
+        if (layer.time_mix_lerp_fused) {
+            // fusing these lerp weights into a single tensor gives a small performance improvement
+            sx  = ggml_reshape_3d(ctx0, sx,  n_embd, 1, n_tokens);
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+            xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer.time_mix_lerp_fused), sx), cur);
+            xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+            xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+            xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+            xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+            xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+        } else {
+            // for backward compatibility
+            xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+            xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+            xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+            xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+            xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+            xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer.time_mix_lerp_w), sx), cur);
+            xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer.time_mix_lerp_k), sx), cur);
+            xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer.time_mix_lerp_v), sx), cur);
+            xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer.time_mix_lerp_r), sx), cur);
+            xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer.time_mix_lerp_g), sx), cur);
+        }
+
+        ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
+        ggml_tensor * k = build_lora_mm(layer.time_mix_key,        xk);
+        ggml_tensor * v = build_lora_mm(layer.time_mix_value,      xv);
+        if (layer.time_mix_receptance_b) {
+            r = ggml_add(ctx0, r, layer.time_mix_receptance_b);
+        }
+        if (layer.time_mix_key_b) {
+            k = ggml_add(ctx0, k, layer.time_mix_key_b);
+        }
+        if (layer.time_mix_value_b) {
+            v = ggml_add(ctx0, v, layer.time_mix_value_b);
+        }
+
+        ggml_tensor * g = build_lora_mm(layer.time_mix_gate, xg);
+        if (is_qrwkv) {
+            g = ggml_sigmoid(ctx0, g);
+        } else {
+            g = ggml_silu(ctx0, g);
+        }
+
+        if (n_head_kv != 0 && n_head_kv != n_head) {
+            GGML_ASSERT(n_head % n_head_kv == 0);
+            k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
+            v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
+            ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
+            k = ggml_repeat(ctx0, k, tmp);
+            v = ggml_repeat(ctx0, v, tmp);
+        }
+
+        k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
+        v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
+        r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);
+
+        ggml_tensor * w = ggml_mul_mat(
+                ctx0,
+                layer.time_mix_decay_w2,
+                ggml_tanh(
+                    ctx0,
+                    ggml_mul_mat(ctx0, layer.time_mix_decay_w1, xw)
+                    )
+                );
+
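+        // time decay: w = exp(-exp(w + time_mix_decay)), mapping the raw decay into (0, 1)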
+        w = ggml_add(ctx0, w, layer.time_mix_decay);
+        w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
+        w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);
+
+        if (is_qrwkv) {
+            // k = k * (1 - w)
+            k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
+        }
+
+        ggml_tensor * wkv_state = build_copy_mask_state(
+                gf, kv_self->v_l[il], state_copy, state_mask,
+                hparams.n_embd_v_s(), n_seqs);
+
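+        // the wkv kernel returns the per-token outputs followed by the final recurrent state of each sequence;
+        // that state is copied back into the cache below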
+        ggml_tensor * wkv_output;
+        if (is_qrwkv) {
+            wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
+        } else {
+            wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer.time_mix_first, w, wkv_state);
+        }
+        cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
+        wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
+        ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_state,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self->v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
+                        )
+                    )
+                );
+
+        if (!is_qrwkv) {
+            // group norm with head_count groups
+            cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
+            cur = ggml_norm(ctx0, cur, 64e-5f);
+
+            // Convert back to regular vectors.
+            cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
+        } else {
+            cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        }
+
+        cur = ggml_mul(ctx0, cur, g);
+        cur = build_lora_mm(layer.time_mix_output, cur);
+
+        return cur;
+    }
+};
+
+struct llm_build_rwkv6 : public llm_build_rwkv6_base {
+    llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
+        GGML_ASSERT(hparams.token_shift_count == 2);
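+        // two token-shift states per layer: one feeds the time mix (attention), the other the channel mix (FFN)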
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        const auto n_embd = hparams.n_embd;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(
+                    gf, state_copy, state_mask, ubatch, il
+                    );
+
+            ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
+
+            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
+            cb(att_norm, "attn_norm", il);
+
+            ggml_tensor * x_prev = ggml_concat(
+                    ctx0,
+                    att_shift,
+                    ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
+            cb(ffn_norm, "ffn_norm", il);
+
+            x_prev = ggml_concat(
+                    ctx0,
+                    ffn_shift,
+                    ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            token_shift = ggml_concat(ctx0,
+                    ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
+                    ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
+                    1
+                    );
+            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
+
+            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
+                cur = ggml_scale(ctx0, cur, 0.5F);
+            }
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
+    llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        const auto n_embd = hparams.n_embd;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(
+                    gf, state_copy, state_mask, ubatch, il
+                    );
+
+            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
+            cb(att_norm, "attn_norm", il);
+
+            ggml_tensor * x_prev = ggml_concat(
+                    ctx0,
+                    token_shift,
+                    ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+
+            token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
+            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// ref: https://github.com/facebookresearch/chameleon
+// based on the original build_llama() function, changes:
+//   * qk-norm
+//   * swin-norm
+//   * removed bias
+//   * removed MoE
+struct llm_build_chameleon : public llm_graph_context {
+    llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            if (hparams.swin_norm) {
+                cur = inpL;
+            } else {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+                            ggml_element_size(Qcur) * n_embd_head,
+                            ggml_element_size(Qcur) * n_embd_head * n_head,
+                            0);
+                    cb(Qcur, "Qcur", il);
+
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            model.layers[il].attn_q_norm_b,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+                            ggml_element_size(Kcur) * n_embd_head,
+                            ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+                            0);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            model.layers[il].attn_k_norm_b,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, nullptr,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+                if (hparams.swin_norm) {
+                    cur = build_norm(cur,
+                            model.layers[il].attn_norm, NULL,
+                            LLM_NORM_RMS, il);
+                }
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            if (!hparams.swin_norm) {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+            }
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            if (hparams.swin_norm) {
+                cur = build_norm(cur,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output_with_img_logits", -1);
+
+        // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
+        // Needs to be removed once image outputs are supported.
+        int img_token_end_idx = 8196;
+        int img_token_start_idx = 4;
+        int num_img_tokens = img_token_end_idx - img_token_start_idx;
+        // creates a 1d tensor of size num_img_tokens filled with -FLT_MAX,
+        // which ensures that text token logits are always larger than image token logits
+        ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens);
+        img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
+        cb(img_logits, "img_logits", -1);
+
+        cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_wavtokenizer_dec : public llm_graph_context {
+    llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
+
+        cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
+        cur = ggml_add(ctx0, cur, model.conv1d_b);
+
+        // posnet
+        for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
+            const auto & layer = model.layers[il].posnet;
+
+            inpL = cur;
+
+            switch (il) {
+                case 0:
+                case 1:
+                case 3:
+                case 4:
+                    {
+                        cur = build_norm(cur,
+                                layer.norm1,
+                                layer.norm1_b,
+                                LLM_NORM_GROUP, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv1_b);
+
+                        cur = build_norm(cur,
+                                layer.norm2,
+                                layer.norm2_b,
+                                LLM_NORM_GROUP, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv2_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 2:
+                    {
+                        cur = build_norm(cur,
+                                layer.attn_norm,
+                                layer.attn_norm_b,
+                                LLM_NORM_GROUP, 0);
+
+                        ggml_tensor * q;
+                        ggml_tensor * k;
+                        ggml_tensor * v;
+
+                        q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
+                        k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
+                        v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);
+
+                        q = ggml_add(ctx0, q, layer.attn_q_b);
+                        k = ggml_add(ctx0, k, layer.attn_k_b);
+                        v = ggml_add(ctx0, v, layer.attn_v_b);
+
+                        q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
+                        k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
+
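+                        // scaled dot-product self-attention (single head) over the time axis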
+                        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+
+                        kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
+
+                        cur = ggml_mul_mat(ctx0, kq, v);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.attn_o_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 5:
+                    {
+                        cur = build_norm(cur,
+                                layer.norm,
+                                layer.norm_b,
+                                LLM_NORM_GROUP, 0);
+                    } break;
+                default: GGML_ABORT("unknown posnet layer");
+            }
+        }
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = build_norm(cur,
+                model.tok_norm,
+                model.tok_norm_b,
+                LLM_NORM, -1);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        inpL = cur;
+
+        // convnext
+        for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
+            const auto & layer = model.layers[il].convnext;
+
+            cur = inpL;
+
+            cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
+            cur = ggml_add(ctx0, cur, layer.dw_b);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            cur = build_norm(cur,
+                    layer.norm,
+                    layer.norm_b,
+                    LLM_NORM, -1);
+
+            cur = build_ffn(cur,
+                    layer.pw1, layer.pw1_b, NULL,
+                    NULL,      NULL,        NULL,
+                    layer.pw2, layer.pw2_b, NULL,
+                    NULL,
+                    LLM_FFN_GELU, LLM_FFN_SEQ, il);
+
+            cur = ggml_mul(ctx0, cur, layer.gamma);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            inpL = ggml_add(ctx0, cur, inpL);
+        }
+
+        cur = inpL;
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cur = ggml_add(ctx0, cur, model.output_b);
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+llama_memory_i * llama_model::create_memory() const {
+    llama_memory_i * res;
+
+    switch (arch) {
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_MAMBA:
+            {
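+                // recurrent models (Mamba, RWKV) keep their states in the cache tensors and need no rope factors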
+                res = new llama_kv_cache_unified(hparams, {
+                    /*.get_rope_factors =*/ nullptr
+                });
+            } break;
+        default:
+            {
+                res = new llama_kv_cache_unified(hparams, {
+                    /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
+                        // choose long/short freq factors based on the context size
+                        if (layers[il].rope_freqs != nullptr) {
+                            return layers[il].rope_freqs;
+                        }
+
+                        if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+                            return layers[il].rope_long;
+                        }
+
+                        return layers[il].rope_short;
+                    }
+                });
+            }
+    }
+
+    return res;
+}
+
+llm_graph_result_ptr llama_model::build_graph(
+        const llm_graph_params & params,
+                   ggml_cgraph * gf,
+                llm_graph_type   type) const {
+    std::unique_ptr<llm_graph_context> llm;
+
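+    // each architecture has a dedicated builder struct whose constructor emits the full forward graph into gf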
+    switch (arch) {
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+            {
+                llm = std::make_unique<llm_build_llama>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DECI:
+            {
+                llm = std::make_unique<llm_build_deci>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BAICHUAN:
+            {
+                llm = std::make_unique<llm_build_baichuan>(*this, params, gf);
+            } break;
+        case LLM_ARCH_FALCON:
+            {
+                llm = std::make_unique<llm_build_falcon>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GROK:
+            {
+                llm = std::make_unique<llm_build_grok>(*this, params, gf);
+            } break;
+        case LLM_ARCH_STARCODER:
+            {
+                llm = std::make_unique<llm_build_starcoder>(*this, params, gf);
+            } break;
+        case LLM_ARCH_REFACT:
+            {
+                llm = std::make_unique<llm_build_refact>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_NOMIC_BERT:
+            {
+                llm = std::make_unique<llm_build_bert>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BLOOM:
+            {
+                llm = std::make_unique<llm_build_bloom>(*this, params, gf);
+            } break;
+        case LLM_ARCH_MPT:
+            {
+                llm = std::make_unique<llm_build_mpt>(*this, params, gf);
+            } break;
+        case LLM_ARCH_STABLELM:
+            {
+                llm = std::make_unique<llm_build_stablelm>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN:
+            {
+                llm = std::make_unique<llm_build_qwen>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN2:
+            {
+                llm = std::make_unique<llm_build_qwen2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN2VL:
+            {
+                llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN2MOE:
+            {
+                llm = std::make_unique<llm_build_qwen2moe>(*this, params, gf);
+            } break;
+        case LLM_ARCH_PHI2:
+            {
+                llm = std::make_unique<llm_build_phi2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
+            {
+                llm = std::make_unique<llm_build_phi3>(*this, params, gf);
+            } break;
+        case LLM_ARCH_PLAMO:
+            {
+                llm = std::make_unique<llm_build_plamo>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GPT2:
+            {
+                llm = std::make_unique<llm_build_gpt2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_CODESHELL:
+            {
+                llm = std::make_unique<llm_build_codeshell>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ORION:
+            {
+                llm = std::make_unique<llm_build_orion>(*this, params, gf);
+            } break;
+        case LLM_ARCH_INTERNLM2:
+            {
+                llm = std::make_unique<llm_build_internlm2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_MINICPM3:
+            {
+                llm = std::make_unique<llm_build_minicpm3>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GEMMA:
+            {
+                llm = std::make_unique<llm_build_gemma>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GEMMA2:
+            {
+                llm = std::make_unique<llm_build_gemma2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                llm = std::make_unique<llm_build_gemma3>(*this, params, gf);
+            } break;
+        case LLM_ARCH_STARCODER2:
+            {
+                llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_MAMBA:
+            {
+                llm = std::make_unique<llm_build_mamba>(*this, params, gf);
+            } break;
+        case LLM_ARCH_XVERSE:
+            {
+                llm = std::make_unique<llm_build_xverse>(*this, params, gf);
+            } break;
+        case LLM_ARCH_COMMAND_R:
+            {
+                llm = std::make_unique<llm_build_command_r>(*this, params, gf);
+            } break;
+        case LLM_ARCH_COHERE2:
+            {
+                llm = std::make_unique<llm_build_cohere2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DBRX:
+            {
+                llm = std::make_unique<llm_build_dbrx>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OLMO:
+            {
+                llm = std::make_unique<llm_build_olmo>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OLMO2:
+            {
+                llm = std::make_unique<llm_build_olmo2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OLMOE:
+            {
+                llm = std::make_unique<llm_build_olmoe>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OPENELM:
+            {
+                llm = std::make_unique<llm_build_openelm>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                llm = std::make_unique<llm_build_gptneox>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ARCTIC:
+            {
+                llm = std::make_unique<llm_build_arctic>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DEEPSEEK:
+            {
+                llm = std::make_unique<llm_build_deepseek>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DEEPSEEK2:
+            {
+                llm = std::make_unique<llm_build_deepseek2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_CHATGLM:
+            {
+                llm = std::make_unique<llm_build_chatglm>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BITNET:
+            {
+                llm = std::make_unique<llm_build_bitnet>(*this, params, gf);
+            } break;
+        case LLM_ARCH_T5:
+            {
+                switch (type) {
+                    case LLM_GRAPH_TYPE_ENCODER:
+                        llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
+                        break;
+                    case LLM_GRAPH_TYPE_DEFAULT:
+                    case LLM_GRAPH_TYPE_DECODER:
+                        llm = std::make_unique<llm_build_t5_dec>(*this, params, gf);
+                        break;
+                    default:
+                        GGML_ABORT("invalid graph type");
+                }
+            } break;
+        case LLM_ARCH_T5ENCODER:
+            {
+                llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
+            } break;
+        case LLM_ARCH_JAIS:
+            {
+                llm = std::make_unique<llm_build_jais>(*this, params, gf);
+            } break;
+        case LLM_ARCH_NEMOTRON:
+            {
+                llm = std::make_unique<llm_build_nemotron>(*this, params, gf);
+            } break;
+        case LLM_ARCH_EXAONE:
+            {
+                llm = std::make_unique<llm_build_exaone>(*this, params, gf);
+            } break;
+        case LLM_ARCH_RWKV6:
+            {
+                llm = std::make_unique<llm_build_rwkv6>(*this, params, gf);
+            } break;
+        case LLM_ARCH_RWKV6QWEN2:
+            {
+                llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_CHAMELEON:
+            {
+                llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
+            } break;
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            {
+                llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+
+    // add on pooling layer
+    llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b);
+
+    return std::move(llm->res);
+}
+
+//
+// interface implementation
+//
+
+llama_model_params llama_model_default_params() {
+    llama_model_params result = {
+        /*.devices                     =*/ nullptr,
+        /*.n_gpu_layers                =*/ 0,
+        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
+        /*.main_gpu                    =*/ 0,
+        /*.tensor_split                =*/ nullptr,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
+        /*.vocab_only                  =*/ false,
+        /*.use_mmap                    =*/ true,
+        /*.use_mlock                   =*/ false,
+        /*.check_tensors               =*/ false,
+    };
+
+#ifdef GGML_USE_METAL
+    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
+    result.n_gpu_layers = 999;
+#endif
+
+    return result;
+}
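+
+// usage sketch (values and path are hypothetical): adjust the defaults before loading a model
+//     llama_model_params mparams = llama_model_default_params();
+//     mparams.n_gpu_layers = 32;
+//     llama_model * model  = llama_model_load_from_file("model.gguf", mparams);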
+
+const llama_vocab * llama_model_get_vocab(const llama_model * model) {
+    return &model->vocab;
+}
+
+void llama_free_model(llama_model * model) {
+    llama_model_free(model);
+}
+
+void llama_model_free(llama_model * model) {
+    delete model;
+}
+
+int32_t llama_model_n_ctx_train(const llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_model_n_embd(const llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_model_n_layer(const llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+int32_t llama_model_n_head(const llama_model * model) {
+    return model->hparams.n_head();
+}
+
+int32_t llama_model_n_head_kv(const llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
+// deprecated
+int32_t llama_n_ctx_train(const llama_model * model) {
+    return llama_model_n_ctx_train(model);
+}
+
+// deprecated
+int32_t llama_n_embd(const llama_model * model) {
+    return llama_model_n_embd(model);
+}
+
+// deprecated
+int32_t llama_n_layer(const llama_model * model) {
+    return llama_model_n_layer(model);
+}
+
+// deprecated
+int32_t llama_n_head(const llama_model * model) {
+    return llama_model_n_head(model);
+}
+
+llama_rope_type llama_model_rope_type(const llama_model * model) {
+    switch (model->arch) {
+        // these models do not use RoPE
+        case LLM_ARCH_GPT2:
+        case LLM_ARCH_GPTJ:
+        case LLM_ARCH_MPT:
+        case LLM_ARCH_REFACT:
+        case LLM_ARCH_BLOOM:
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_T5:
+        case LLM_ARCH_T5ENCODER:
+        case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            return LLAMA_ROPE_TYPE_NONE;
+
+        // use what we call a normal RoPE, operating on pairs of consecutive head values
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_DECI:
+        case LLM_ARCH_BAICHUAN:
+        case LLM_ARCH_STARCODER:
+        case LLM_ARCH_PLAMO:
+        case LLM_ARCH_ORION:
+        case LLM_ARCH_INTERNLM2:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_XVERSE:
+        case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_COHERE2:
+        case LLM_ARCH_OLMO:
+        case LLM_ARCH_ARCTIC:
+        case LLM_ARCH_DEEPSEEK:
+        case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_CHATGLM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_CHAMELEON:
+            return LLAMA_ROPE_TYPE_NORM;
+
+        // the pairs of head values are offset by n_rot/2
+        case LLM_ARCH_FALCON:
+        case LLM_ARCH_GROK:
+        case LLM_ARCH_DBRX:
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_STABLELM:
+        case LLM_ARCH_BITNET:
+        case LLM_ARCH_QWEN:
+        case LLM_ARCH_QWEN2:
+        case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_OLMO2:
+        case LLM_ARCH_OLMOE:
+        case LLM_ARCH_PHI2:
+        case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
+        case LLM_ARCH_GEMMA:
+        case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
+        case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
+        case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_NEMOTRON:
+        case LLM_ARCH_EXAONE:
+        case LLM_ARCH_MINICPM3:
+            return LLAMA_ROPE_TYPE_NEOX;
+
+        case LLM_ARCH_QWEN2VL:
+            return LLAMA_ROPE_TYPE_MROPE;
+
+        // all model arches should be listed explicitly here
+        case LLM_ARCH_UNKNOWN:
+            GGML_ABORT("unknown architecture");
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
+float llama_model_rope_freq_scale_train(const llama_model * model) {
+    return model->hparams.rope_freq_scale_train;
+}
+
+int32_t llama_model_meta_val_str(const llama_model * model, const char * key, char * buf, size_t buf_size) {
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
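
Note the snprintf-style contract: the return value is the length the value would need, or -1 when the key is absent. A usage sketch, assuming the standard GGUF key general.architecture is present in the file:

char arch[128];
if (llama_model_meta_val_str(model, "general.architecture", arch, sizeof(arch)) >= 0) {
    printf("architecture: %s\n", arch);  // truncated if longer than the buffer
}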
+
+int32_t llama_model_meta_count(const llama_model * model) {
+    return (int)model->gguf_kv.size();
+}
+
+int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_model_meta_val_str_by_index(const llama_model * model, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
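
Together with llama_model_meta_count, the two *_by_index accessors allow enumerating all GGUF metadata. A minimal sketch (hypothetical helper; fixed-size buffers, so long values are truncated):

static void dump_metadata(const llama_model * model) {
    char key[256];
    char val[256];
    const int32_t n = llama_model_meta_count(model);
    for (int32_t i = 0; i < n; ++i) {
        if (llama_model_meta_key_by_index    (model, i, key, sizeof(key)) < 0) continue;
        if (llama_model_meta_val_str_by_index(model, i, val, sizeof(val)) < 0) continue;
        printf("%s = %s\n", key, val);
    }
}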
+
+int32_t llama_model_desc(const llama_model * model, char * buf, size_t buf_size) {
+    return snprintf(buf, buf_size, "%s", model->desc().c_str());
+}
+
+uint64_t llama_model_size(const llama_model * model) {
+    return model->size();
+}
+
+const char * llama_model_chat_template(const llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
+}
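
Unlike the buffer-based metadata getters, this returns a pointer into the model's own storage, or nullptr when the GGUF carries no template. A caller-side sketch:

const char * tmpl = llama_model_chat_template(model, /*name =*/ nullptr);
if (tmpl != nullptr) {
    printf("chat template:\n%s\n", tmpl);  // valid for the lifetime of the model
}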
+
+uint64_t llama_model_n_params(const llama_model * model) {
+    return model->n_elements();
+}
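
A hedged sketch combining llama_model_desc, llama_model_size, and llama_model_n_params into a one-line summary (the format string is illustrative):

char desc[256];
llama_model_desc(model, desc, sizeof(desc));
printf("%s | %.2f GiB | %.2f B params\n",
        desc,
        llama_model_size(model)     / (1024.0 * 1024.0 * 1024.0),
        llama_model_n_params(model) / 1e9);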
+
+bool llama_model_has_encoder(const llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5:        return true;
+        case LLM_ARCH_T5ENCODER: return true;
+        default:                 return false;
+    }
+}
+
+bool llama_model_has_decoder(const llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5ENCODER: return false;
+        default:                 return true;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const llama_model * model) {
+    return model->hparams.dec_start_token_id;
+}
+
+bool llama_model_is_recurrent(const llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA:      return true;
+        case LLM_ARCH_RWKV6:      return true;
+        case LLM_ARCH_RWKV6QWEN2: return true;
+        default:                  return false;
+    }
+}
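
A sketch of the capability checks a caller might run before wiring up a pipeline; the comments reflect how the examples in this repository use these queries, not a guaranteed contract:

if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
    // encoder-decoder (e.g. T5): run llama_encode() first, then start
    // decoding from llama_model_decoder_start_token(model)
}
if (llama_model_is_recurrent(model)) {
    // Mamba/RWKV: rolling recurrent state instead of a conventional KV cache
}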
+
+const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
+    return model->tensors_by_name;
 }

+ 18 - 1
src/llama-model.h

@@ -2,7 +2,9 @@
 
 #include "llama.h"
 #include "llama-arch.h"
+#include "llama-graph.h"
 #include "llama-hparams.h"
+#include "llama-memory.h"
 #include "llama-vocab.h"
 
 #include <memory>
@@ -10,6 +12,8 @@
 #include <unordered_map>
 #include <vector>
 
+struct llama_cparams;
+struct llama_ubatch;
 struct llama_model_loader;
 
 // available models
@@ -347,7 +351,7 @@ struct llama_model {
     std::string desc() const;
 
     size_t size() const;
-    size_t max_nodes() const;
+    size_t n_tensors() const;
     size_t n_devices() const;
 
     // total number of parameters in the model
@@ -362,9 +366,22 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
+    // TODO: move this to new llm_arch_model_i interface
+    llama_memory_i * create_memory() const; // TODO: params
+
+    // TODO: move this to new llm_arch_model_i interface
+    llm_graph_result_ptr build_graph(
+            const llm_graph_params & params,
+                       ggml_cgraph * gf,
+                    llm_graph_type   type) const;
+
 private:
     struct impl;
     std::unique_ptr<impl> pimpl;
 };
 
 const char * llm_type_name(llm_type type);
+
+// For internal test use
+// TODO: remove
+const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);

File diff suppressed because it is too large
+ 0 - 9489
src/llama.cpp


Some files were not shown because too many files changed in this diff