Explorar o código

llama : fix KV shift for qwen2vl (#13870)

* llama : fix KV shift for qwen2vl

* add ref to the PR
Xuan-Son Nguyen hai 7 meses
pai
achega
763d06edb7
Modificáronse 2 ficheiros con 11 adicións e 3 borrados
  1. 1 1
      src/llama-graph.cpp
  2. 10 2
      src/llama-kv-cache.cpp

+ 1 - 1
src/llama-graph.cpp

@@ -455,7 +455,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     }
     }
 
 
 int64_t llm_graph_context::n_pos_per_embd() const {
 int64_t llm_graph_context::n_pos_per_embd() const {
-    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
 }
 }
 
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {

+ 10 - 2
src/llama-kv-cache.cpp

@@ -757,11 +757,19 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
     const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
     const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
 
 
     const auto & n_rot     = hparams.n_rot;
     const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type;
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+                                // @ngxson : this is a workaround
+                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
+                                // a normal RoPE should work, we just need to use the correct ordering
+                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
+                                ? LLAMA_ROPE_TYPE_NEOX
+                                : hparams.rope_type;
 
 
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
+                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
+                                    : cparams.yarn_attn_factor;
 
 
     ggml_tensor * tmp;
     ggml_tensor * tmp;