
models : fix the attn_factor for mistral3 graphs + improve consistency (#17945)

* models : fix the attn_factor for mistral3 graphs

* cont : rework attn_factor correction logic

* cont : make deepseek2 consistent

* cont : add TODO

* cont : special-case DSv2

* cont : revert Mistral 3 Large changes

* cont : fix DS2 to use the original attn_factor

* cont : minor comments
Georgi Gerganov 1 month ago
Parent
Commit 7bed317f53
7 changed files with 78 additions and 33 deletions
  1. + 8 - 0   convert_hf_to_gguf.py
  2. + 1 - 1   src/llama-graph.cpp
  3. + 12 - 0  src/llama-hparams.cpp
  4. + 8 - 1   src/llama-hparams.h
  5. + 4 - 9   src/llama-kv-cache.cpp
  6. + 36 - 17 src/llama-model.cpp
  7. + 9 - 5   src/models/deepseek2.cpp

+ 8 - 0
convert_hf_to_gguf.py

@@ -7286,6 +7286,10 @@ class DeepseekV2Model(TextModel):
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
             self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
             self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+            # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+            # ref https://github.com/ggml-org/llama.cpp/pull/17945
             self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
 
     _experts: list[dict[str, Tensor]] | None = None
@@ -10041,6 +10045,10 @@ class MistralMoeModel(DeepseekV2Model):
         MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
         yarn_params = self.hparams["yarn"]
         self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
+
+        # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+        # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+        # ref https://github.com/ggml-org/llama.cpp/pull/17945
         self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
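
A minimal standalone sketch (not part of this commit) of the legacy round-trip described in the comments above: the convert script stores 0.1 * mscale_all_dim under rope_scaling_yarn_log_mul, and the loader divides the value back out. The 0.707 figure is only an assumed, DeepSeek-V2-style mscale_all_dim.

#include <cstdio>

int main() {
    const float mscale_all_dim = 0.707f;                    // assumed config value (illustrative)

    // convert_hf_to_gguf.py stores the pre-multiplied value for legacy reasons
    const float stored_yarn_log_mul = 0.1f * mscale_all_dim;

    // llama_model::load_hparams (llama-model.cpp below) divides it back out
    const float recovered = stored_yarn_log_mul / 0.1f;

    printf("stored = %.4f, recovered mscale_all_dim = %.4f\n", stored_yarn_log_mul, recovered);
    return 0;
}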

+ 1 - 1
src/llama-graph.cpp

@@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     freq_base        (cparams.rope_freq_base),
     freq_scale       (cparams.rope_freq_scale),
     ext_factor       (cparams.yarn_ext_factor),
-    attn_factor      (cparams.yarn_attn_factor),
+    attn_factor      (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
     beta_fast        (cparams.yarn_beta_fast),
     beta_slow        (cparams.yarn_beta_slow),
     norm_eps         (hparams.f_norm_eps),

+ 12 - 0
src/llama-hparams.cpp

@@ -1,7 +1,9 @@
 #include "llama-hparams.h"
 
 #include "ggml.h"
+
 #include <cassert>
+#include <cmath>
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
@@ -229,3 +231,13 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama
 
     return false;
 }
+
+float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
+    GGML_ASSERT(ext_factor >= 0.0f);
+
+    if (ext_factor != 0.0f) {
+        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+    }
+
+    return attn_factor;
+}
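
A minimal standalone sketch (not part of this commit) of what yarn_attn_factor_adjust cancels: when ext_factor != 0.0f, ggml's RoPE multiplies the attention magnitude by 1 + 0.1 * ln(1 / freq_scale), so pre-dividing by the same factor leaves the user-requested attn_factor as the effective value. The freq_scale and ext_factor values below are assumptions for illustration.

#include <cmath>
#include <cstdio>

int main() {
    // assumed illustrative values: 4x context extension with YaRN enabled
    const float attn_factor_user = 1.0f;   // cparams.yarn_attn_factor
    const float freq_scale       = 0.25f;
    const float ext_factor       = 1.0f;

    // what yarn_attn_factor_adjust returns
    const float adjusted = ext_factor != 0.0f
        ? attn_factor_user * (1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)))
        : attn_factor_user;

    // ggml's RoPE multiplies the same magnitude factor back in when ext_factor != 0,
    // so the effective attention scaling ends up being the user-requested value again
    const float effective = adjusted * (1.0f + 0.1f * logf(1.0f / freq_scale));

    printf("adjusted = %.4f, effective = %.4f\n", adjusted, effective);
    return 0;
}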

+ 8 - 1
src/llama-hparams.h

@@ -107,6 +107,7 @@ struct llama_hparams {
     float    rope_freq_base_train_swa;
     float    rope_freq_scale_train;
     float    rope_freq_scale_train_swa;
+
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul = 0.0f;
 
@@ -267,7 +268,13 @@ struct llama_hparams {
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+
+    // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+    // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+    //
+    // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+    //      https://github.com/ggml-org/llama.cpp/pull/17945
+    static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
-

+ 4 - 9
src/llama-kv-cache.cpp

@@ -1369,9 +1369,10 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
                       float   freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
-    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
+    const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
+    const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
+    const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
 
     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
@@ -1382,12 +1383,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
                                 ? LLAMA_ROPE_TYPE_NEOX
                                 : hparams.rope_type;
 
-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
-                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
-                                    : cparams.yarn_attn_factor;
-
     ggml_tensor * tmp;
 
     if (ggml_is_quantized(cur->type)) {

+ 36 - 17
src/llama-model.cpp

@@ -1635,7 +1635,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // that have no expert_gating_func model parameter set
                     hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
                 }
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+
+                if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
+                    // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+                    // cancel the factor from the convert script
+                    hparams.rope_yarn_log_mul /= 0.1f;
+                }
 
                 // (optional) temperature tuning - used by mistral-large
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE,  hparams.f_attn_temp_scale,       false);
@@ -2267,9 +2272,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
 
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   hparams.yarn_beta_fast, false);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   hparams.yarn_beta_slow, false);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,     hparams.rope_yarn_log_mul, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast,    false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow,    false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,   hparams.rope_yarn_log_mul, 0.0f);
 
                 // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
                 if (hparams.f_attn_temp_scale != 0.0f) {
@@ -2279,18 +2284,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     }
                 }
 
-                // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
-                //       but may need further verification with other values
-                if (hparams.rope_yarn_log_mul != 0.0f) {
-                    float factor = 1.0f / hparams.rope_freq_scale_train;
-                    float mscale = 1.0f;
-                    float mscale_all_dims = hparams.rope_yarn_log_mul;
-                    static auto get_mscale = [](float scale, float mscale) {
-                        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
-                    };
-                    hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
-                }
-
                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_3B; break;
                     case 34: type = LLM_TYPE_8B; break;
@@ -2301,6 +2294,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         default: throw std::runtime_error("unsupported model architecture");
     }
 
+    // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+    if (hparams.rope_yarn_log_mul != 0.0f) {
+        const float factor = 1.0f / hparams.rope_freq_scale_train;
+
+        // note: here we assume `mscale == 1.0f`
+        // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+              float mscale          = 1.0f;
+        const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+        // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+        // special-case DEEPSEEK v2:
+        // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+        if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+            mscale = mscale_all_dims;
+        }
+
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+        LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
+    }
+
     pimpl->n_bytes = ml.n_bytes;
 
     pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
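
A minimal standalone sketch (not part of this commit) of the get_mscale ratio above, with assumed DeepSeek-V2-style numbers (rope factor 40, mscale_all_dim around 0.707): the generic path takes mscale as 1.0f, while the DSv2 special case sets mscale = mscale_all_dims, which makes the ratio collapse to 1.0f and keeps the original attn_factor.

#include <cmath>
#include <cstdio>

// same helper as in the hunk above
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
}

int main() {
    // assumed values, roughly DeepSeek-V2 style
    const float rope_freq_scale_train = 1.0f / 40.0f;
    const float mscale_all_dims       = 0.707f;   // hparams.rope_yarn_log_mul after the /0.1f fix

    const float factor = 1.0f / rope_freq_scale_train;

    // generic path assumes mscale == 1.0f
    const float generic = get_mscale(factor, 1.0f) / get_mscale(factor, mscale_all_dims);

    // DSv2 special case: mscale = mscale_all_dims, so the ratio collapses to 1.0f
    const float dsv2 = get_mscale(factor, mscale_all_dims) / get_mscale(factor, mscale_all_dims);

    printf("yarn_attn_factor: generic = %.4f, DSv2 special-case = %.4f\n", generic, dsv2);
    return 0;
}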
@@ -6806,6 +6825,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
         LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
         LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
+        LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n",   __func__, hparams.rope_yarn_log_mul);
         LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
         // MRoPE (Multi-axis Rotary Position Embedding) sections
         if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
@@ -6869,7 +6889,6 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",     __func__, hparams.expert_weights_norm);
         LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
-        LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
 
     if (arch == LLM_ARCH_QWEN2MOE) {

+ 9 - 5
src/models/deepseek2.cpp

@@ -1,7 +1,5 @@
 #include "models.h"
 
-
-
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
@@ -20,9 +18,15 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float mscale      = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
-    const float kq_scale    = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
-    const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+    // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+
+    // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
+    GGML_ASSERT(ext_factor >= 0.0f);
+    const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));
+
+    // use the original attn_factor to pre-scale the kq_scale
+    const float mscale   = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+    const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
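
A minimal standalone sketch (not part of this commit) of the cancel-and-reapply step above, with assumed DeepSeek-V2-style numbers: the builder first multiplies the already-adjusted attn_factor back up to recover the original value, then folds it into kq_scale via mscale. The head size and scaling values below are assumptions for illustration.

#include <cmath>
#include <cstdio>

int main() {
    // assumed illustrative values, roughly DeepSeek-V2 style
    const float freq_scale        = 1.0f / 40.0f;  // training-time rope scale
    const float rope_yarn_log_mul = 0.707f;        // mscale_all_dim recovered by the loader
    const float n_embd_head_k     = 192.0f;        // assumed MLA head size (128 nope + 64 rope)
    const float attn_factor_user  = 1.0f;          // cparams.yarn_attn_factor

    // the graph receives the adjusted value (llama_hparams::yarn_attn_factor_adjust)
    const float attn_factor = attn_factor_user / (1.0f + 0.1f * logf(1.0f / freq_scale));

    // step 1 in the builder: cancel the adjustment to recover the original attn_factor
    const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));

    // step 2: fold the original value into the KQ scale
    const float mscale   = attn_factor_org * (1.0f + 0.1f * rope_yarn_log_mul * logf(1.0f / freq_scale));
    const float kq_scale = 1.0f * mscale * mscale / sqrtf(n_embd_head_k);

    printf("attn_factor_org = %.4f, mscale = %.4f, kq_scale = %.6f\n", attn_factor_org, mscale, kq_scale);
    return 0;
}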