Procházet zdrojové kódy

model : support Rnj-1 (#17811)

* add support for rnj1

* refactor gemma3 to support rnj-1

* address review comments
philip-essential před 1 měsícem
rodič
revize
1d2a1ab73d
5 změnil soubory, kde provedl 76 přidání a 24 odebrání
  1. 25 10
      convert_hf_to_gguf.py
  2. 1 1
      src/CMakeLists.txt
  3. 17 6
      src/llama-model.cpp
  4. 30 5
      src/models/gemma3.cpp
  5. 3 2
      src/models/models.h

+ 25 - 10
convert_hf_to_gguf.py

@@ -5825,9 +5825,11 @@ class Gemma3Model(TextModel):
     norm_shift = 1.0  # Gemma3RMSNorm adds 1.0 to the norm value
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
-
-        self.gguf_writer.add_add_space_prefix(False)
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+            self.gguf_writer.add_add_space_prefix(False)
+        else:
+            self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
         hparams = self.hparams
@@ -5845,13 +5847,24 @@ class Gemma3Model(TextModel):
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
         # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        if (final_logit_softcap := hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+        if hparams.get("sliding_window_pattern") != 1:
+            self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
-            assert hparams["rope_scaling"]["rope_type"] == "linear"
-            # important: this rope_scaling is only applied for global layers, and not used by 1B model
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+            rope_scaling = hparams["rope_scaling"]
+            if rope_scaling["rope_type"] == "linear":
+                # important: this rope_scaling is only applied for global layers, and not used by 1B model
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            elif rope_scaling["rope_type"] == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+                self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
@@ -5865,8 +5878,10 @@ class Gemma3Model(TextModel):
 
         # remove OOV (out-of-vocabulary) rows in token_embd
         if "embed_tokens.weight" in name:
-            vocab = self._create_vocab_sentencepiece()
-            tokens = vocab[0]
+            if (self.dir_model / "tokenizer.model").is_file():
+                tokens = self._create_vocab_sentencepiece()[0]
+            else:
+                tokens = self.get_vocab_base()[0]
             data_torch = data_torch[:len(tokens)]
 
         # ref code in Gemma3RMSNorm

+ 1 - 1
src/CMakeLists.txt

@@ -67,7 +67,7 @@ add_library(llama
             models/gemma-embedding.cpp
             models/gemma.cpp
             models/gemma2-iswa.cpp
-            models/gemma3-iswa.cpp
+            models/gemma3.cpp
             models/gemma3n-iswa.cpp
             models/glm4-moe.cpp
             models/glm4.cpp

+ 17 - 6
src/llama-model.cpp

@@ -1264,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.set_swa_pattern(6);
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.set_swa_pattern(6);
 
-                hparams.rope_freq_base_train_swa  = 10000.0f;
-                hparams.rope_freq_scale_train_swa = 1.0f;
+                    hparams.rope_freq_base_train_swa  = 10000.0f;
+                    hparams.rope_freq_scale_train_swa = 1.0f;
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }
 
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                hparams.f_final_logit_softcapping = 0.0f;
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
                     case 18: type = LLM_TYPE_270M; break;
                     case 26: type = LLM_TYPE_1B; break;
+                    case 32: type = LLM_TYPE_8B; break; // Rnj-1
                     case 34: type = LLM_TYPE_4B; break;
                     case 48: type = LLM_TYPE_12B; break;
                     case 62: type = LLM_TYPE_27B; break;
@@ -7304,7 +7311,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_gemma3<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_gemma3<false>>(*this, params);
+                }
             } break;
         case LLM_ARCH_GEMMA3N:
             {

+ 30 - 5
src/models/gemma3-iswa.cpp → src/models/gemma3.cpp

@@ -1,6 +1,7 @@
 #include "models.h"
 
-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;
 
     ggml_tensor * cur;
@@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     ggml_tensor * inp_pos = build_inp_pos();
 
     // TODO: is causal == true correct? might need some changes
-    auto * inp_attn = build_attn_inp_kv_iswa();
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    inp_attn_type * inp_attn = nullptr;
+
+    if constexpr (iswa) {
+        inp_attn = build_attn_inp_kv_iswa();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
-        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        float freq_base_l  = 0.0f;
+        float freq_scale_l = 0.0f;
+
+        if constexpr (iswa) {
+            freq_base_l  = model.get_rope_freq_base (cparams, il);
+            freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        } else {
+            freq_base_l  = freq_base;
+            freq_scale_l = freq_scale;
+        }
 
         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
         cur = build_norm(cur,
                 model.layers[il].ffn_post_norm, NULL,
                 LLM_NORM_RMS, -1);
-        cb(cur, "ffn_post_norm", -1);
+        cb(cur, "ffn_post_norm", il);
 
         cur = ggml_add(ctx0, cur, sa_out);
 
@@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     // lm_head
     cur = build_lora_mm(model.output, cur);
 
+    if (hparams.f_final_logit_softcapping) {
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+    }
+
     cb(cur, "result_output", -1);
     res->t_logits = cur;
 
     ggml_build_forward_expand(gf, cur);
 }
+
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;

+ 3 - 2
src/models/models.h

@@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
     llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_gemma3_iswa : public llm_graph_context {
-    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
+template <bool iswa>
+struct llm_build_gemma3 : public llm_graph_context {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
 };
 
 struct llm_build_gemma3n_iswa : public llm_graph_context {