
models : fix YaRN regression + consolidate logic (#18006)

* models : fix YaRN regression + consolidate logic

* cont : fix the fix

* cont : remove header

* cont : add header
Georgi Gerganov, 1 month ago
Commit 609a2d0268
6 changed files with 40 additions and 46 deletions
  1. src/llama-context.cpp (+38, -0)
  2. src/llama-graph.cpp (+1, -1)
  3. src/llama-hparams.cpp (+0, -11)
  4. src/llama-hparams.h (+0, -7)
  5. src/llama-kv-cache.cpp (+1, -1)
  6. src/llama-model.cpp (+0, -26)

src/llama-context.cpp (+38, -0)

@@ -9,6 +9,7 @@
 #include "llama-model.h"
 #include "llama-model.h"
 
 
 #include <cinttypes>
 #include <cinttypes>
+#include <cmath>
 #include <cstring>
 #include <cstring>
 #include <limits>
 #include <limits>
 #include <stdexcept>
 #include <stdexcept>
@@ -72,6 +73,43 @@ llama_context::llama_context(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.yarn_ext_factor != 0) {
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        const float factor = 1.0f / cparams.rope_freq_scale;
+
+        // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+        if (hparams.rope_yarn_log_mul != 0.0f) {
+            // note: here we assume `mscale == 1.0f`
+            // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+                  float mscale          = 1.0f;
+            const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            // special-case DEEPSEEK v2:
+            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+            if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+                mscale = mscale_all_dims;
+            }
+
+            cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+            LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                    __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
+        } else {
+            cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
+        }
+
+        // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+        // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+        //
+        // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+        //      https://github.com/ggml-org/llama.cpp/pull/17945
+        cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
+    }
+
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;

     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
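
For reference, here is a minimal, self-contained sketch of what the hunk above computes. The input values (freq_scale = 1/40, mscale_all_dim = 0.707) are assumed for illustration only, loosely following the DeepSeek-V2-Lite config linked in the comments; get_mscale and main below are standalone and not part of the llama.cpp API:

    #include <cmath>
    #include <cstdio>

    // YaRN magnitude scaling: 0.1 * mscale * ln(scale) + 1 for scale > 1, else 1
    static float get_mscale(float scale, float mscale) {
        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
    }

    int main() {
        // assumed example values (DeepSeek-V2-Lite-style YaRN config)
        const float freq_scale      = 1.0f / 40.0f;
        const float factor          = 1.0f / freq_scale; // 40.0
        const float mscale_all_dims = 0.707f;
        const float mscale          = mscale_all_dims;   // DeepSeek2 special case: mscale taken from mscale_all_dim

        // ratio of the two magnitude scales; equals 1.0 whenever mscale == mscale_all_dims
        float yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);

        // cancel the 1 + 0.1*ln(1/freq_scale) factor that ggml applies internally when ext_factor != 0
        yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));

        printf("yarn_attn_factor = %.4f\n", yarn_attn_factor); // ~0.7305 with these values
        return 0;
    }

With mscale == mscale_all_dim the first ratio is exactly 1, so the effective adjustment reduces to the cancellation term; the diff then multiplies cparams.yarn_attn_factor by hparams.rope_attn_factor as before.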

src/llama-graph.cpp (+1, -1)

@@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     freq_base        (cparams.rope_freq_base),
     freq_scale       (cparams.rope_freq_scale),
     ext_factor       (cparams.yarn_ext_factor),
-    attn_factor      (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
+    attn_factor      (cparams.yarn_attn_factor),
     beta_fast        (cparams.yarn_beta_fast),
     beta_slow        (cparams.yarn_beta_slow),
     norm_eps         (hparams.f_norm_eps),

src/llama-hparams.cpp (+0, -11)

@@ -3,7 +3,6 @@
 #include "ggml.h"
 #include "ggml.h"
 
 
 #include <cassert>
 #include <cassert>
-#include <cmath>
 
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
     if (dense_first) {
@@ -231,13 +230,3 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama

     return false;
 }
-
-float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
-    GGML_ASSERT(ext_factor >= 0.0f);
-
-    if (ext_factor != 0.0f) {
-        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
-    }
-
-    return attn_factor;
-}

src/llama-hparams.h (+0, -7)

@@ -268,13 +268,6 @@ struct llama_hparams {
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
-
-    // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
-    // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
-    //
-    // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
-    //      https://github.com/ggml-org/llama.cpp/pull/17945
-    static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
 };

 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

src/llama-kv-cache.cpp (+1, -1)

@@ -1372,7 +1372,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
     const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
     const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
-    const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
+    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
 
 
     const auto & n_rot     = hparams.n_rot;
     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
     const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE

src/llama-model.cpp (+0, -26)

@@ -2294,32 +2294,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         default: throw std::runtime_error("unsupported model architecture");
     }
 
-    // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
-    if (hparams.rope_yarn_log_mul != 0.0f) {
-        const float factor = 1.0f / hparams.rope_freq_scale_train;
-
-        // note: here we assume `mscale == 1.0f`
-        // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
-              float mscale          = 1.0f;
-        const float mscale_all_dims = hparams.rope_yarn_log_mul;
-
-        // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
-        // special-case DEEPSEEK v2:
-        // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
-        if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
-            mscale = mscale_all_dims;
-        }
-
-        static auto get_mscale = [](float scale, float mscale) {
-            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
-        };
-
-        hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
-
-        LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
-                __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
-    }
-
     pimpl->n_bytes = ml.n_bytes;

     pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();