Browse Source

hparams : add n_embd_inp() to support extended embed (#16928)

* add n_embd_full to support extended embed

* don't change output

* rename to n_embd_inp

* restore n_embd where applicable
Sigbjørn Skjæret 2 months ago
parent
commit
9008027aa3

+ 1 - 0
include/llama.h

@@ -486,6 +486,7 @@ extern "C" {
 
 
     LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);

+ 3 - 3
src/llama-context.cpp

@@ -827,7 +827,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
 
     const auto & hparams = model.hparams;
     const auto & hparams = model.hparams;
 
 
-    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_embd  = hparams.n_embd_inp();
     const int64_t n_vocab = model.vocab.n_tokens();
     const int64_t n_vocab = model.vocab.n_tokens();
 
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     // note: during encode, we always pass the full sequence starting from pos = 0
@@ -996,7 +996,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
     const auto & hparams = model.hparams;
 
 
     const int64_t n_vocab = vocab.n_tokens();
     const int64_t n_vocab = vocab.n_tokens();
-    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_embd  = hparams.n_embd_inp();
 
 
     // when computing embeddings, all tokens are output
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
     const bool output_all = cparams.embeddings;
@@ -2154,7 +2154,7 @@ void llama_context::opt_epoch_iter(
             batch.logits  [pos_batch]    = true;
             batch.logits  [pos_batch]    = true;
         }
         }
 
 
-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
             LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
             LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
             return;
             return;
         }
         }

+ 2 - 2
src/llama-graph.cpp

@@ -1142,7 +1142,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
 
 // input embeddings with optional lora
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_inp();
 
 
     auto inp = std::make_unique<llm_graph_input_embd>();
     auto inp = std::make_unique<llm_graph_input_embd>();
 
 
@@ -1279,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     //    return cur;
     //    return cur;
     //}
     //}
 
 
-    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
     const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
     const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
 
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);

+ 10 - 0
src/llama-hparams.cpp

@@ -60,6 +60,16 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
     return n_head/n_head_kv;
     return n_head/n_head_kv;
 }
 }
 
 
+uint32_t llama_hparams::n_embd_inp() const {
+    uint32_t n_embd_inp = n_embd;
+
+    if (n_deepstack_layers > 0) {
+        n_embd_inp += n_embd * n_deepstack_layers;
+    }
+
+    return n_embd_inp;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
     const uint32_t n_head_kv = this->n_head_kv(il);
 
 

+ 3 - 0
src/llama-hparams.h

@@ -227,6 +227,9 @@ struct llama_hparams {
 
 
     uint32_t n_gqa(uint32_t il = 0) const;
     uint32_t n_gqa(uint32_t il = 0) const;
 
 
+    // dimension of main + auxiliary input embeddings
+    uint32_t n_embd_inp() const;
+
     // dimension of key embeddings across all k-v heads
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
 

+ 7 - 16
src/llama-model.cpp

@@ -276,8 +276,8 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             } break;
             } break;
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL:
             {
             {
-                const int n_embd = hparams.n_embd;
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
+                const int n_embd_inp = hparams.n_embd_inp();
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1);
                 op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
                 op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
             } break;
             } break;
         case GGML_OP_SCALE:
         case GGML_OP_SCALE:
@@ -1039,9 +1039,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 64: type = LLM_TYPE_32B; break;
                     case 64: type = LLM_TYPE_32B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
                 }
-                // since vision model stacks deepstack features along feature dim
-                // we also create a fake "n_embd" for text model to be the main embd + deepstack embds
-                hparams.n_embd *= hparams.n_deepstack_layers + 1;
             } break;
             } break;
         case LLM_ARCH_QWEN3MOE:
         case LLM_ARCH_QWEN3MOE:
             {
             {
@@ -1065,9 +1062,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 94: type = LLM_TYPE_235B_A22B; break;
                     case 94: type = LLM_TYPE_235B_A22B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
                 }
-                // since vision model stacks deepstack features along feature dim
-                // we also create a fake "n_embd" for text model to be the main embd + deepstack embds
-                hparams.n_embd *= hparams.n_deepstack_layers + 1;
             } break;
             } break;
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI2:
             {
             {
@@ -3341,10 +3335,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_QWEN3:
             case LLM_ARCH_QWEN3:
             case LLM_ARCH_QWEN3VL:
             case LLM_ARCH_QWEN3VL:
                 {
                 {
-                    // for model loading, the weights only have the main embd
-                    // so we need to divide by the number of deepstack layers + 1
-                    // n_embd is const int so we declare a new variable
-                    int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
 
                     // output
                     // output
@@ -3380,10 +3370,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_QWEN3MOE:
             case LLM_ARCH_QWEN3MOE:
             case LLM_ARCH_QWEN3VLMOE:
             case LLM_ARCH_QWEN3VLMOE:
                 {
                 {
-                    // for model loading, the weights only have the main embd
-                    // so we need to divide by the number of deepstack layers + 1
-                    // n_embd is const int so we declare a new variable
-                    int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
 
                     // output
                     // output
@@ -6535,6 +6521,7 @@ void llama_model::print_info() const {
     if (!hparams.vocab_only) {
     if (!hparams.vocab_only) {
         LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
         LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
         LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
         LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
+        LLAMA_LOG_INFO("%s: n_embd_inp       = %u\n",     __func__, hparams.n_embd_inp());
         LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
         LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
         LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
@@ -7380,6 +7367,10 @@ int32_t llama_model_n_embd(const llama_model * model) {
     return model->hparams.n_embd;
     return model->hparams.n_embd;
 }
 }
 
 
+int32_t llama_model_n_embd_inp(const llama_model * model) {
+    return model->hparams.n_embd_inp();
+}
+
 int32_t llama_model_n_layer(const llama_model * model) {
 int32_t llama_model_n_layer(const llama_model * model) {
     return model->hparams.n_layer;
     return model->hparams.n_layer;
 }
 }

+ 1 - 2
src/models/qwen3vl-moe.cpp

@@ -1,9 +1,8 @@
 #include "models.h"
 #include "models.h"
 
 
 llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
 llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+    const int64_t n_embd = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

+ 1 - 4
src/models/qwen3vl.cpp

@@ -1,13 +1,10 @@
 #include "models.h"
 #include "models.h"
 
 
 llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
 llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-
-    const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1);
+    const int64_t n_embd = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
 
-
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
 

+ 1 - 1
tools/mtmd/mtmd.cpp

@@ -163,7 +163,7 @@ struct mtmd_context {
         print_timings(ctx_params.print_timings),
         print_timings(ctx_params.print_timings),
         n_threads    (ctx_params.n_threads),
         n_threads    (ctx_params.n_threads),
         media_marker (ctx_params.media_marker),
         media_marker (ctx_params.media_marker),
-        n_embd_text  (llama_model_n_embd(text_model))
+        n_embd_text  (llama_model_n_embd_inp(text_model))
     {
     {
         if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
         if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
             throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
             throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");