@@ -9,6 +9,7 @@
 #include "llama-model.h"
 
 #include <cinttypes>
+#include <cmath>
 #include <cstring>
 #include <limits>
 #include <stdexcept>
@@ -72,6 +73,43 @@ llama_context::llama_context(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.yarn_ext_factor != 0) {
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        const float factor = 1.0f / cparams.rope_freq_scale;
+
+        // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+        if (hparams.rope_yarn_log_mul != 0.0f) {
+            // note: here we assume `mscale == 1.0f`
+            // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+            float mscale = 1.0f;
+            const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            // special-case DEEPSEEK v2:
+            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+            if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+                mscale = mscale_all_dims;
+            }
+
+            cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+            LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                    __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
+        } else {
+            cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
+        }
+
+        // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+        // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+        //
+        // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+        //      https://github.com/ggml-org/llama.cpp/pull/17945
+        cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
+    }
+
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {