llama : implement YaRN RoPE scaling (#2268)

Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
cebtenzzre 2 years ago
parent commit 898aeca90a

+ 66 - 13
common/common.cpp

@@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                break;
            }
            params.rope_freq_scale = std::stof(argv[i]);
+        } else if (arg == "--rope-scaling") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+            else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            else { invalid_param = true; break; }
        } else if (arg == "--rope-scale") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+        } else if (arg == "--yarn-orig-ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_orig_ctx = std::stoi(argv[i]);
+        } else if (arg == "--yarn-ext-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_ext_factor = std::stof(argv[i]);
+        } else if (arg == "--yarn-attn-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_attn_factor = std::stof(argv[i]);
+        } else if (arg == "--yarn-beta-fast") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_fast = std::stof(argv[i]);
+        } else if (arg == "--yarn-beta-slow") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_slow = std::stof(argv[i]);
        } else if (arg == "--memory-f32") {
            params.memory_f16 = false;
        } else if (arg == "--top-p") {
@@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    printf("  --cfg-negative-prompt-file FNAME\n");
    printf("                        negative prompt file to use for guidance. (default: empty)\n");
    printf("  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
-    printf("  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
+    printf("  --rope-scaling {none,linear,yarn}\n");
+    printf("                        RoPE frequency scaling method, defaults to linear unless specified by the model\n");
+    printf("  --rope-scale N        RoPE context scaling factor, expands context by a factor of N\n");
    printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
-    printf("  --rope-freq-scale N   RoPE frequency linear scaling factor (default: loaded from model)\n");
+    printf("  --rope-freq-scale N   RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-orig-ctx N     YaRN: original context size of model (default: 0 = model training context size)\n");
+    printf("  --yarn-ext-factor N   YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+    printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+    printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+    printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
    printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
    printf("  --no-penalize-nl      do not penalize newline token\n");
    printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -826,17 +873,23 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
    auto cparams = llama_context_default_params();

-    cparams.n_ctx           = params.n_ctx;
-    cparams.n_batch         = params.n_batch;
-    cparams.n_threads       = params.n_threads;
-    cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
-    cparams.mul_mat_q       = params.mul_mat_q;
-    cparams.seed            = params.seed;
-    cparams.f16_kv          = params.memory_f16;
-    cparams.logits_all      = params.logits_all;
-    cparams.embedding       = params.embedding;
-    cparams.rope_freq_base  = params.rope_freq_base;
-    cparams.rope_freq_scale = params.rope_freq_scale;
+    cparams.n_ctx             = params.n_ctx;
+    cparams.n_batch           = params.n_batch;
+    cparams.n_threads         = params.n_threads;
+    cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    cparams.mul_mat_q         = params.mul_mat_q;
+    cparams.seed              = params.seed;
+    cparams.f16_kv            = params.memory_f16;
+    cparams.logits_all        = params.logits_all;
+    cparams.embedding         = params.embedding;
+    cparams.rope_scaling_type = params.rope_scaling_type;
+    cparams.rope_freq_base    = params.rope_freq_base;
+    cparams.rope_freq_scale   = params.rope_freq_scale;
+    cparams.yarn_ext_factor   = params.yarn_ext_factor;
+    cparams.yarn_attn_factor  = params.yarn_attn_factor;
+    cparams.yarn_beta_fast    = params.yarn_beta_fast;
+    cparams.yarn_beta_slow    = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx     = params.yarn_orig_ctx;

    return cparams;
}
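Taken together, the new options compose like this on the command line (model path and values are illustrative, not from this commit):

./main -m model.gguf -c 16384 --rope-scaling yarn --rope-freq-scale 0.25 --yarn-orig-ctx 4096

Per the parsing above, --rope-scale N is simply stored as rope_freq_scale = 1/N, so the two flags are alternate spellings of the same setting.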

+ 7 - 0
common/common.h

@@ -9,6 +9,7 @@
#define LOG_NO_FILE_LINE_FUNCTION
#include "log.h"

+#include <cmath>
#include <string>
#include <vector>
#include <random>
@@ -54,6 +55,12 @@ struct gpt_params {
    int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
    float   rope_freq_base                  = 0.0f; // RoPE base frequency
    float   rope_freq_scale                 = 0.0f; // RoPE frequency scaling factor
+    float   yarn_ext_factor                 = NAN;  // YaRN extrapolation mix factor
+    float   yarn_attn_factor                = 1.0f; // YaRN magnitude scaling factor
+    float   yarn_beta_fast                  = 32.0f; // YaRN low correction dim
+    float   yarn_beta_slow                  = 1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx                   = 0;    // YaRN original context length
+    int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED;

    // sampling parameters
    struct llama_sampling_params sparams;
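Note the NAN default on yarn_ext_factor (and the new <cmath> include): unlike the other fields it doubles as an "unset" sentinel, so the runtime can tell "user passed 0" apart from "user passed nothing". A minimal sketch of that pattern, with a hypothetical resolver whose name and fallback values are illustrative rather than this commit's actual loader logic:

#include <math.h>

/* Hypothetical helper: pick a concrete ext_factor when the CLI left it unset.
 * NAN means "not specified"; a YaRN-scaled model plausibly defaults to full
 * extrapolation mixing (1.0f), anything else to plain interpolation (0.0f). */
static float resolve_yarn_ext_factor(float cli_value, int model_uses_yarn) {
    if (isnan(cli_value)) {
        return model_uses_yarn ? 1.0f : 0.0f;
    }
    return cli_value; /* an explicit user value always wins */
}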

+ 2 - 1
convert-baichuan-hf-to-gguf.py

@@ -163,7 +163,8 @@ gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
    if "type" in hparams["rope_scaling"]:
        if hparams["rope_scaling"]["type"] == "linear":
-            gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
+            gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])


# TOKENIZATION

+ 48 - 49
convert.py

@@ -151,8 +151,11 @@ class Params:
    n_head_kv:  int
    f_norm_eps: float

+    rope_scaling_type: gguf.RopeScalingType | None = None
    f_rope_freq_base: float | None = None
    f_rope_scale: float | None = None
+    n_orig_ctx: int | None = None
+    rope_finetuned: bool | None = None

    ftype: GGMLFileType | None = None

@@ -198,20 +201,20 @@ class Params:
    def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
        config = json.load(open(config_path))

-        n_vocab          = config["vocab_size"]
-        n_embd           = config["hidden_size"]
-        n_layer          = config["num_hidden_layers"]
-        n_ff             = config["intermediate_size"]
-        n_head           = config["num_attention_heads"]
-        n_head_kv        = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
-        f_norm_eps       = config["rms_norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
+        rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
        rope_scaling = config.get("rope_scaling")
-        if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
-            f_rope_scale = config["rope_scaling"].get("factor")
-        else:
-            f_rope_scale = None
+
+        if rope_scaling is not None and (typ := rope_scaling.get("type")):
+            rope_factor = rope_scaling.get("factor")
+            f_rope_scale = rope_factor
+            if typ == "linear":
+                rope_scaling_type = gguf.RopeScalingType.LINEAR
+            elif typ == "yarn":
+                rope_scaling_type = gguf.RopeScalingType.YARN
+                n_orig_ctx = rope_scaling['original_max_position_embeddings']
+                rope_finetuned = rope_scaling['finetuned']
+            else:
+                raise NotImplementedError(f'Unknown rope scaling type: {typ}')

        if "max_sequence_length" in config:
            n_ctx = config["max_sequence_length"]
@@ -222,16 +225,19 @@ class Params:
                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")

        return Params(
-            n_vocab          = n_vocab,
-            n_embd           = n_embd,
-            n_layer          = n_layer,
-            n_ctx            = n_ctx,
-            n_ff             = n_ff,
-            n_head           = n_head,
-            n_head_kv        = n_head_kv,
-            f_norm_eps       = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
-            f_rope_scale     = f_rope_scale,
+            n_vocab           = config["vocab_size"],
+            n_embd            = config["hidden_size"],
+            n_layer           = config["num_hidden_layers"],
+            n_ctx             = n_ctx,
+            n_ff              = config["intermediate_size"],
+            n_head            = (n_head := config["num_attention_heads"]),
+            n_head_kv         = config.get("num_key_value_heads", n_head),
+            f_norm_eps        = config["rms_norm_eps"],
+            f_rope_freq_base  = config.get("rope_theta"),
+            rope_scaling_type = rope_scaling_type,
+            f_rope_scale      = f_rope_scale,
+            n_orig_ctx        = n_orig_ctx,
+            rope_finetuned    = rope_finetuned,
        )
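For reference, this is the shape of the Hugging Face config.json block the branch above consumes; the numbers are illustrative, but the key names are exactly the ones read here:

"rope_scaling": {
    "type": "yarn",
    "factor": 8.0,
    "original_max_position_embeddings": 4096,
    "finetuned": true
}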

    # LLaMA v2 70B params.json
@@ -240,17 +246,8 @@ class Params:
    def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
        config = json.load(open(config_path))

-        n_vocab          = config["vocab_size"] if "vocab_size" in config else -1
-        n_embd           = config["dim"]
-        n_layer          = config["n_layers"]
-        n_ff             = -1
-        n_head           = config["n_heads"]
-        n_head_kv        = config["n_kv_heads"] if "n_kv_heads" in config else n_head
-        f_norm_eps       = config["norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
        # hack to determine LLaMA v1 vs v2 vs CodeLlama
-        if f_rope_freq_base == 1000000:
+        if config.get("rope_theta") == 1000000:
            # CodeLlama
            n_ctx = 16384
        elif config["norm_eps"] == 1e-05:
@@ -260,22 +257,16 @@ class Params:
            # LLaMA v1
            n_ctx = 2048

-        if n_vocab == -1:
-            n_vocab = model["tok_embeddings.weight"].shape[0]
-
-        if n_ff == -1:
-            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
-
        return Params(
-            n_vocab          = n_vocab,
-            n_embd           = n_embd,
-            n_layer          = n_layer,
+            n_vocab          = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
+            n_embd           = config["dim"],
+            n_layer          = config["n_layers"],
            n_ctx            = n_ctx,
-            n_ff             = n_ff,
-            n_head           = n_head,
-            n_head_kv        = n_head_kv,
-            f_norm_eps       = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
+            n_ff             = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_head           = (n_head := config["n_heads"]),
+            n_head_kv        = config.get("n_kv_heads", n_head),
+            f_norm_eps       = config["norm_eps"],
+            f_rope_freq_base = config.get("rope_theta"),
        )

    @staticmethod
@@ -831,8 +822,16 @@ class OutputFile:
        if params.f_rope_freq_base is not None:
            self.gguf.add_rope_freq_base(params.f_rope_freq_base)

-        if params.f_rope_scale is not None:
-            self.gguf.add_rope_scale_linear(params.f_rope_scale)
+        if params.rope_scaling_type:
+            assert params.f_rope_scale is not None
+            self.gguf.add_rope_scaling_type(params.rope_scaling_type)
+            self.gguf.add_rope_scaling_factor(params.f_rope_scale)
+
+        if params.n_orig_ctx is not None:
+            self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
+
+        if params.rope_finetuned is not None:
+            self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)

        if params.ftype is not None:
            self.gguf.add_file_type(params.ftype)
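A note for readers tracing the metadata path: these add_rope_scaling_* calls persist the settings as per-architecture GGUF keys (on the order of {arch}.rope.scaling.type, .factor, .original_context_length, and .finetuned; the exact constant names live in the gguf package), which is what lets the loader later default rope_scaling_type and yarn_orig_ctx from the model file.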

+ 3 - 2
examples/finetune/finetune.cpp

@@ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
        const int rope_mode = 0;

        return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx,
-            rope_freq_base, rope_freq_scale);
+            t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
+            rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+        );
    };

    set_name(tokens_input, "tokens_input");
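Both trainers now go through the widened ggml_rope_custom. For orientation, a sketch of the full argument order under the new signature; the tensor handles and numeric values here are placeholders, not values taken from this commit:

/* illustrative call; see the ggml.c hunk below for the signature */
struct ggml_tensor * rotated = ggml_rope_custom(
    ctx, cur, KQ_pos,
    n_rot, /*mode=*/0, /*n_ctx=*/16384, /*n_orig_ctx=*/4096,
    /*freq_base=*/10000.0f, /*freq_scale=*/0.25f,
    /*ext_factor=*/1.0f, /*attn_factor=*/1.0f,
    /*beta_fast=*/32.0f, /*beta_slow=*/1.0f
);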

+ 55 - 4
examples/server/server.cpp

@@ -1755,12 +1755,18 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    printf("options:\n");
    printf("  -h, --help                show this help message and exit\n");
    printf("  -v, --verbose             verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
-    printf("  -t N,  --threads N        number of threads to use during computation (default: %d)\n", params.n_threads);
+    printf("  -t N, --threads N         number of threads to use during computation (default: %d)\n", params.n_threads);
    printf("  -tb N, --threads-batch N  number of threads to use during batch and prompt processing (default: same as --threads)\n");
-    printf("  -c N,  --ctx-size N       size of the prompt context (default: %d)\n", params.n_ctx);
+    printf("  -c N, --ctx-size N        size of the prompt context (default: %d)\n", params.n_ctx);
+    printf("  --rope-scaling {none,linear,yarn}\n");
+    printf("                            RoPE frequency scaling method, defaults to linear unless specified by the model\n");
    printf("  --rope-freq-base N        RoPE base frequency (default: loaded from model)\n");
-    printf("  --rope-freq-scale N       RoPE frequency scaling factor (default: loaded from model)\n");
-    printf("  -b N,  --batch-size N     batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  --rope-freq-scale N       RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-ext-factor N       YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+    printf("  --yarn-attn-factor N      YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+    printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+    printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
    printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
    printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
    if (llama_mlock_supported())
@@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.n_ctx = std::stoi(argv[i]);
        }
+        else if (arg == "--rope-scaling")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+            else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            else { invalid_param = true; break; }
+        }
        else if (arg == "--rope-freq-base")
        {
            if (++i >= argc)
@@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.rope_freq_scale = std::stof(argv[i]);
        }
+        else if (arg == "--yarn-ext-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_ext_factor = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-attn-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_attn_factor = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-beta-fast")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_fast = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-beta-slow")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_slow = std::stof(argv[i]);
+        }
        else if (arg == "--memory-f32" || arg == "--memory_f32")
        {
            params.memory_f16 = false;
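An illustrative server invocation with the new flags (values are placeholders); note that this parser adds the YaRN tuning flags but not --yarn-orig-ctx:

./server -m model.gguf -c 16384 --rope-scaling yarn --rope-freq-scale 0.25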

+ 3 - 3
examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs(
        // not capturing these, to silence warnings
        const int rope_mode = 0;

-        return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx,
-            rope_freq_base, rope_freq_scale);
+        return ggml_rope_custom(
+            ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+        );
    };

    set_name(tokens_input, "tokens_input");

+ 111 - 42
ggml-cuda.cu

@@ -4493,11 +4493,41 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
    cpy_1(cx + x_offset, cdst + dst_offset);
}

-// rope == RoPE == rotary positional embedding
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+struct rope_corr_dims {
+    float v[4];
+};
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// rope == RoPE == rotary positional embedding
template<typename T, bool has_pos>
-static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                            const int p_delta_rows, const float theta_scale) {
+static __global__ void rope(
+    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (col >= ncols) {
@@ -4509,10 +4539,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
    const int i2 = row/p_delta_rows;

    const int p = has_pos ? pos[i2] : 0;
-    const float p0 = p*freq_scale;
-    const float theta = p0*powf(theta_scale, col/2);
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
+    const float theta_base = p*powf(freq_base, -col/ncols);
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + 1];
@@ -4522,8 +4552,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
}

template<typename T, bool has_pos>
-static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                 const int p_delta_rows, const float theta_scale) {
+static __global__ void rope_neox(
+    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (col >= ncols) {
@@ -4534,11 +4566,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
    const int i = row*ncols + col/2;
    const int i2 = row/p_delta_rows;

+    // simplified from `(row * ncols + col) * (-1 / ncols)`
+    const float cur_rot = -col/ncols - row;
+
    const int p = has_pos ? pos[i2] : 0;
-    const float p0 = p*freq_scale;
-    const float theta = p0*powf(theta_scale, col/2);
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
+    const float theta_base = p*powf(freq_base, cur_rot);
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

    const float x0 = x[i + 0];
    const float x1 = x[i + ncols/2];
@@ -4547,8 +4582,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
}

-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                    const int p_delta_rows, const float theta_scale, const int n_ctx) {
+static __global__ void rope_glm_f32(
+    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    int n_ctx
+) {
    const int col = blockDim.x*blockIdx.x + threadIdx.x;
    const int half_n_dims = ncols/4;

@@ -4560,7 +4597,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
    const int i = row*ncols + col;
    const int i2 = row/p_delta_rows;

-    const float col_theta_scale = powf(theta_scale, col);
+    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
     // FIXME: this is likely wrong
    const int p = pos != nullptr ? pos[i2] : 0;

@@ -5584,40 +5621,54 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
}

template<typename T>
-static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_cuda(
+    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
    GGML_ASSERT(ncols % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nrows, num_blocks_x, 1);
    if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
    } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
    }
}

template<typename T>
-static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_neox_cuda(
+    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
    GGML_ASSERT(ncols % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nrows, num_blocks_x, 1);
    if (pos == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
    } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
    }
}

-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                              const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+static void rope_glm_f32_cuda(
+    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, int n_ctx, cudaStream_t stream
+) {
    GGML_ASSERT(ncols % 4 == 0);
    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
    const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
}

static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -6477,17 +6528,20 @@ inline void ggml_cuda_op_rope(
    const int64_t ne2 = dst->ne[2];
    const int64_t nrows = ggml_nrows(src0);

-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    // RoPE alteration for extended context
-
-    float freq_base, freq_scale;
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    //const int n_past      = ((int32_t *) dst->op_params)[0];
+    const int n_dims      = ((int32_t *) dst->op_params)[1];
+    const int mode        = ((int32_t *) dst->op_params)[2];
+    const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
 
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    // RoPE alteration for extended context
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));

    const int32_t * pos = nullptr;
    if ((mode & 1) == 0) {
@@ -6499,24 +6553,39 @@ inline void ggml_cuda_op_rope(
    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;

+    rope_corr_dims corr_dims;
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+
    // compute
    if (is_glm) {
        GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
+        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
    } else if (is_neox) {
        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_neox_cuda(
+                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_neox_cuda(
+                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
        } else {
            GGML_ASSERT(false);
        }
    } else {
        if (src0->type == GGML_TYPE_F32) {
-            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_cuda(
+                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
        } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_cuda(
+                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
        } else {
            GGML_ASSERT(false);
        }
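Restating what the rope_yarn device function above computes, with s = freq_scale and m = attn_factor:

\theta = (1 - \gamma)\,s\,\theta_{\mathrm{extrap}} + \gamma\,\theta_{\mathrm{extrap}}, \qquad \gamma = \mathrm{ramp}(i_0)\cdot\mathrm{ext\_factor}

and, when ext_factor is nonzero, the magnitude correction m \leftarrow m\,(1 + 0.1\,\ln(1/s)) is applied to both \cos\theta and \sin\theta. With ext_factor = 0 the ramp term vanishes and this reduces to plain linear position scaling, \theta = s\,\theta_{\mathrm{extrap}}, with no magnitude correction.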

+ 16 - 8
ggml-metal.m

@@ -1400,14 +1400,18 @@ void ggml_metal_graph_compute(

                            const int nth = MIN(1024, ne00);

-                            const int n_past = ((int32_t *) dst->op_params)[0];
-                            const int n_dims = ((int32_t *) dst->op_params)[1];
-                            const int mode   = ((int32_t *) dst->op_params)[2];
-
-                            float freq_base;
-                            float freq_scale;
-                            memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+                            const int n_past     = ((int32_t *) dst->op_params)[0];
+                            const int n_dims     = ((int32_t *) dst->op_params)[1];
+                            const int mode       = ((int32_t *) dst->op_params)[2];
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; // [3] is n_ctx, unused here
+
+                            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                            memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+                            memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+                            memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+                            memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+                            memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+                            memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));

                            switch (src0->type) {
                                case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
@@ -1439,6 +1443,10 @@ void ggml_metal_graph_compute(
                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:21];
                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22];
                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
+                            [encoder setBytes:&ext_factor  length:sizeof(float) atIndex:24];
+                            [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
+                            [encoder setBytes:&beta_fast   length:sizeof(float) atIndex:26];
+                            [encoder setBytes:&beta_slow   length:sizeof(float) atIndex:27];

                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                        } break;
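For reference, both backends decode the same op_params layout packed by ggml_rope_impl (see the ggml.c hunk below): int32 slots [0] n_past, [1] n_dims, [2] mode, [3] n_ctx, [4] n_orig_ctx, followed by floats at offsets 5 through 12: freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, and xpos_down (a bool).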

+ 55 - 6
ggml-metal.metal

@@ -1061,6 +1061,45 @@ kernel void kernel_alibi_f32(
    }
}

+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
typedef void (rope_t)(
        device const    void * src0,
        device const int32_t * src1,
@@ -1116,6 +1155,10 @@ kernel void kernel_rope(
        constant         int & mode,
+        constant         int & n_orig_ctx,
        constant       float & freq_base,
        constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg[[threads_per_threadgroup]],
        uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1125,19 +1168,22 @@ kernel void kernel_rope(

    const bool is_neox = mode & 2;

+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
    device const int32_t * pos = src1;

    const int64_t p = pos[i2];

-    const float theta_0 = freq_scale * (float)p;
+    const float theta_0 = (float)p;
    const float inv_ndims = -1.f/n_dims;

    if (!is_neox) {
        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {

            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-            const float cos_theta = cos(theta);
-            const float sin_theta = sin(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -1152,9 +1198,12 @@ kernel void kernel_rope(
        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {

-                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
-                const float cos_theta = cos(theta);
-                const float sin_theta = sin(theta);
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

                const int64_t i0 = ib*n_dims + ic/2;

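The correction-dim formula in rope_yarn_corr_factor above follows from counting rotations. Dimension pair d advances by b^{-2d/n} radians per position (b = freq_base, n = n_dims), so across the original context L = n_orig_ctx it completes

r(d) = \frac{L}{2\pi\, b^{2d/n}}

full turns; solving r(d) = n_{\mathrm{rot}} for d gives

d = \frac{n}{2\ln b}\,\ln\!\left(\frac{L}{2\pi\, n_{\mathrm{rot}}}\right)

which is the expression in the code. beta_fast and beta_slow are the two rotation-count thresholds whose solutions, floored/ceiled and clamped to [0, n - 1], become the start and end of the ramp.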

+ 175 - 66
ggml.c

@@ -1,4 +1,5 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC

#include "ggml-impl.h"
#include "ggml-quants.h"
@@ -4845,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl(
        int                   n_dims,
        int                   mode,
        int                   n_ctx,
+        int                   n_orig_ctx,
        float                 freq_base,
        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow,
        float                 xpos_base,
        bool                  xpos_down,
        bool                  inplace) {
@@ -4862,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl(

    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

-    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
-    memcpy(params + 4, &freq_base,  sizeof(float));
-    memcpy(params + 5, &freq_scale, sizeof(float));
-    memcpy(params + 6, &xpos_base,  sizeof(float));
-    memcpy(params + 7, &xpos_down,  sizeof(bool));
+    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(params + 11, &xpos_base,    sizeof(float));
+    memcpy(params + 12, &xpos_down,    sizeof(bool));
    ggml_set_op_params(result, params, sizeof(params));

    result->op   = GGML_OP_ROPE;
@@ -4884,7 +4894,9 @@ struct ggml_tensor * ggml_rope(
        int                   n_dims,
        int                   mode,
        int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+    );
}

struct ggml_tensor * ggml_rope_inplace(
@@ -4894,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace(
        int                   n_dims,
        int                   mode,
        int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+    );
}

struct ggml_tensor * ggml_rope_custom(
@@ -4904,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom(
        int                   n_dims,
        int                   mode,
        int                   n_ctx,
+        int                   n_orig_ctx,
        float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
}

struct ggml_tensor * ggml_rope_custom_inplace(
@@ -4916,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
        int                   n_dims,
        int                   mode,
        int                   n_ctx,
+        int                   n_orig_ctx,
        float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+    );
}

struct ggml_tensor * ggml_rope_xpos_inplace(
@@ -4928,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
        int                   n_dims,
        float                 base,
        bool                  down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
+    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
}

// ggml_rope_back
@@ -10901,6 +10931,45 @@ static void ggml_compute_forward_clamp(

// ggml_compute_forward_rope

+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+    return 1 - MIN(1, MAX(0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = MAX(0,         floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
static void ggml_compute_forward_rope_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
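For intuition about the correction range computed by ggml_rope_yarn_corr_dims() above: dims below dims[0] rotate fast enough to complete full periods inside the original context and are left at their unscaled ("extrapolated") frequencies, dims above dims[1] are fully interpolated by freq_scale, and rope_yarn() blends linearly in between. A self-contained sketch reproducing the formula — all concrete values below (n_dims = 128, n_orig_ctx = 4096, freq_base = 10000, beta_fast = 32, beta_slow = 1) are illustrative assumptions, not read from any particular model:

    #include <math.h>
    #include <stdio.h>

    // mirrors ggml_rope_yarn_corr_dim() from the hunk above
    static float corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
        const float pi = 3.14159265f;
        return n_dims * logf(n_orig_ctx / (n_rot * 2 * pi)) / (2 * logf(base));
    }

    int main(void) {
        const int   n_dims = 128, n_orig_ctx = 4096;   // illustrative values
        const float base = 10000.0f, beta_fast = 32.0f, beta_slow = 1.0f;

        const float lo = floorf(corr_dim(n_dims, n_orig_ctx, beta_fast, base)); // ~20
        const float hi = ceilf (corr_dim(n_dims, n_orig_ctx, beta_slow, base)); // ~46
        // ggml_rope_yarn_corr_dims() additionally clamps the result to [0, n_dims - 1]
        printf("corr_dims = [%g, %g]\n", lo, hi);
        return 0;
    }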
@@ -10910,21 +10979,26 @@ static void ggml_compute_forward_rope_f32(
        return;
    }

-    float freq_base;
-    float freq_scale;
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;

    // these two only relevant for xPos RoPE:
    float xpos_base;
    bool  xpos_down;

-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&xpos_base,   (int32_t *) dst->op_params + 11, sizeof(float));
+    memcpy(&xpos_down,   (int32_t *) dst->op_params + 12, sizeof(bool));

    GGML_TENSOR_UNARY_OP_LOCALS

@@ -10952,6 +11026,9 @@ static void ggml_compute_forward_rope_f32(
    int ir = 0;

    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.f/n_dims;
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;
@@ -10965,18 +11042,18 @@ static void ggml_compute_forward_rope_f32(
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;

-                float theta = freq_scale * (float)p;
+                float theta_base = (float)p;

                if (is_glm) {
-                    theta = MIN(p, n_ctx - 2);
+                    theta_base = MIN(p, n_ctx - 2);
                    float block_theta = MAX(p - (n_ctx - 2), 0);
                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
                        const float cos_block_theta = cosf(block_theta);
                        const float sin_block_theta = sinf(block_theta);

-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
                        block_theta *= theta_scale;

                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -10994,13 +11071,16 @@ static void ggml_compute_forward_rope_f32(
                    }
                } else if (!is_neox) {
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        float cos_theta, sin_theta;
+                        rope_yarn(
+                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+                        );
+
                        // zeta scaling for xPos only:
                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                        if (xpos_down) zeta = 1.0f / zeta;

-                        theta *= theta_scale;
+                        theta_base *= theta_scale;

                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11014,12 +11094,19 @@ static void ggml_compute_forward_rope_f32(
                } else {
                    // TODO: this might be wrong for ne0 != n_dims - need double check
                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    theta_base *= freq_scale;
                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            // simplified from `(ib * n_dims + ic) * inv_ndims`
+                            float cur_rot = inv_ndims * ic - ib;
+
+                            float cos_theta, sin_theta;
+                            rope_yarn(
+                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                &cos_theta, &sin_theta
+                            );

-                            theta *= theta_scale;
+                            theta_base *= theta_scale;

                            const int64_t i0 = ib*n_dims + ic/2;

@@ -11048,15 +11135,19 @@ static void ggml_compute_forward_rope_f16(
        return;
    }

-    float freq_base;
-    float freq_scale;
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;

-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));

    GGML_TENSOR_UNARY_OP_LOCALS

@@ -11084,6 +11175,9 @@ static void ggml_compute_forward_rope_f16(
    int ir = 0;

    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.f/n_dims;
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;
@@ -11097,18 +11191,18 @@ static void ggml_compute_forward_rope_f16(
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;

-                float theta = freq_scale * (float)p;
+                float theta_base = (float)p;

                if (is_glm) {
-                    theta = MIN(p, n_ctx - 2);
+                    theta_base = MIN(p, n_ctx - 2);
                    float block_theta = MAX(p - (n_ctx - 2), 0);
                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
                        const float cos_block_theta = cosf(block_theta);
                        const float sin_block_theta = sinf(block_theta);

-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
                        block_theta *= theta_scale;

                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -11126,10 +11220,12 @@ static void ggml_compute_forward_rope_f16(
                    }
                } else if (!is_neox) {
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        float cos_theta, sin_theta;
+                        rope_yarn(
+                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+                        );

-                        theta *= theta_scale;
+                        theta_base *= theta_scale;

                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11143,12 +11239,19 @@ static void ggml_compute_forward_rope_f16(
                } else {
                    // TODO: this might be wrong for ne0 != n_dims - need double check
                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    theta_base *= freq_scale;
                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            // simplified from `(ib * n_dims + ic) * inv_ndims`
+                            float cur_rot = inv_ndims * ic - ib;

-                            theta *= theta_scale;
+                            float cos_theta, sin_theta;
+                            rope_yarn(
+                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                &cos_theta, &sin_theta
+                            );
+
+                            theta_base *= theta_scale;

                            const int64_t i0 = ib*n_dims + ic/2;

@@ -11256,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32(
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;

-                float theta = freq_scale * (float)p;
+                float theta_base = freq_scale * (float)p;

                if (!is_neox) {
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
+
                        // zeta scaling for xPos only:
                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                        if (xpos_down) zeta = 1.0f / zeta;

-                        theta *= theta_scale;
+                        theta_base *= theta_scale;

                        const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                              float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11280,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32(
                } else {
                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            const float cos_theta = cosf(theta_base);
+                            const float sin_theta = sinf(theta_base);

-                            theta *= theta_scale;
+                            theta_base *= theta_scale;

                            const int64_t i0 = ib*n_dims + ic/2;

@@ -11356,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16(
                if (ir++ < ir0) continue;
                if (ir   > ir1) break;

-                float theta = (float)p;
+                float theta_base = (float)p;

                if (!is_neox) {
                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);

-                        theta *= theta_scale;
+                        theta_base *= theta_scale;

                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11377,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16(
                } else {
                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            const float cos_theta = cosf(theta_base);
+                            const float sin_theta = sinf(theta_base);

-                            theta *= theta_scale;
+                            theta_base *= theta_scale;

                            const int64_t i0 = ib*n_dims + ic/2;

@@ -15505,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                src1,
                                n_dims,
                                mode,
+                                0,
                                n_ctx,
                                freq_base,
                                freq_scale,
+                                0.0f,
+                                1.0f,
+                                0.0f,
+                                0.0f,
                                xpos_base,
                                xpos_down,
                                false),
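The f32/f16 readers above unpack thirteen int32-sized op_params slots: indices 0-4 for the ints (n_past, n_dims, mode, n_ctx, n_orig_ctx), 5-10 for the YaRN floats, 11-12 for xPos. A hypothetical writer-side sketch mirroring that layout — the actual packing in ggml_rope_impl is outside this diff, so treat the variable names and the helper shape as assumptions:

    // hypothetical mirror of the read side shown above
    int32_t params[13] = { /*n_past (unused)*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
    memcpy(params +  5, &freq_base,   sizeof(float));
    memcpy(params +  6, &freq_scale,  sizeof(float));
    memcpy(params +  7, &ext_factor,  sizeof(float));
    memcpy(params +  8, &attn_factor, sizeof(float));
    memcpy(params +  9, &beta_fast,   sizeof(float));
    memcpy(params + 10, &beta_slow,   sizeof(float));
    memcpy(params + 11, &xpos_base,   sizeof(float));
    memcpy(params + 12, &xpos_down,   sizeof(bool));

At 13 * 4 = 52 bytes this no longer fits the old 32-byte limit, which is why GGML_MAX_OP_PARAMS doubles to 64 in ggml.h below.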

+ 17 - 3
ggml.h

@@ -219,7 +219,7 @@
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          64
-#define GGML_MAX_OP_PARAMS     32
+#define GGML_MAX_OP_PARAMS     64
 #define GGML_DEFAULT_N_THREADS 4

 #if UINTPTR_MAX == 0xFFFFFFFF
@@ -1326,8 +1326,13 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);

     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
@@ -1337,8 +1342,17 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // compute correction dims for YaRN RoPE scaling
+    void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);

     // xPos RoPE, in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
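A minimal call sketch against the extended ggml_rope_custom declaration above; the tensor names and all concrete numbers are illustrative, not taken from a real model:

    // 4x linear stretch plus the full YaRN ramp
    struct ggml_tensor * rotated = ggml_rope_custom(
        ctx, Qcur, inp_pos,
        /*n_dims=*/128, /*mode=*/0, /*n_ctx=*/0, /*n_orig_ctx=*/4096,
        /*freq_base=*/10000.0f, /*freq_scale=*/0.25f,
        /*ext_factor=*/1.0f, /*attn_factor=*/1.0f,
        /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);

Passing ext_factor = 0.0f degenerates to plain linear (freq_scale-only) RoPE scaling, since rope_yarn() then skips the ramp and the magnitude correction entirely.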

+ 23 - 6
gguf-py/gguf/gguf.py

@@ -7,7 +7,7 @@ import shutil
 import struct
 import sys
 import tempfile
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from io import BufferedWriter
 from pathlib import Path
 from typing import IO, Any, BinaryIO, Callable, Sequence
@@ -53,9 +53,12 @@ KEY_ATTENTION_LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
 KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"

 # RoPE
-KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
-KEY_ROPE_FREQ_BASE       = "{arch}.rope.freq_base"
-KEY_ROPE_SCALE_LINEAR    = "{arch}.rope.scale_linear"
+KEY_ROPE_DIMENSION_COUNT         = "{arch}.rope.dimension_count"
+KEY_ROPE_FREQ_BASE               = "{arch}.rope.freq_base"
+KEY_ROPE_SCALING_TYPE            = "{arch}.rope.scaling.type"
+KEY_ROPE_SCALING_FACTOR          = "{arch}.rope.scaling.factor"
+KEY_ROPE_SCALING_ORIG_CTX_LEN    = "{arch}.rope.scaling.original_context_length"
+KEY_ROPE_SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned"

 # tokenization
 KEY_TOKENIZER_MODEL      = "tokenizer.ggml.model"
@@ -577,6 +580,11 @@ class TokenType(IntEnum):
     UNUSED       = 5
     BYTE         = 6

+class RopeScalingType(Enum):
+    NONE   = 'none'
+    LINEAR = 'linear'
+    YARN   = 'yarn'
+
 #
 # implementation
 #
@@ -948,8 +956,17 @@ class GGUFWriter:
     def add_rope_freq_base(self, value: float):
         self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)

-    def add_rope_scale_linear(self, value: float):
-        self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
+    def add_rope_scaling_type(self, value: RopeScalingType):
+        self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
+
+    def add_rope_scaling_factor(self, value: float):
+        self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_orig_ctx_len(self, value: int):
+        self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+
+    def add_rope_scaling_finetuned(self, value: bool):
+        self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)

     def add_tokenizer_model(self, model: str):
         self.add_string(KEY_TOKENIZER_MODEL, model)

+ 167 - 53
llama.cpp

@@ -54,6 +54,7 @@
 #include <cassert>
 #include <cinttypes>
 #include <climits>
+#include <cmath>
 #include <cstdarg>
 #include <cstddef>
 #include <cstdint>
@@ -235,6 +236,10 @@ enum llm_kv {
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
+    LLM_KV_ROPE_SCALING_TYPE,
+    LLM_KV_ROPE_SCALING_FACTOR,
+    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
+    LLM_KV_ROPE_SCALING_FINETUNED,

     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_LIST,
@@ -276,9 +281,13 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_LAYERNORM_EPS,       "%s.attention.layer_norm_epsilon"     },
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,   "%s.attention.layer_norm_rms_epsilon" },

-    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count" },
-    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"       },
-    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"    },
+    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
+    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
+    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
+    { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
+    { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
+    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
+    { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               },

     { LLM_KV_TOKENIZER_MODEL,               "tokenizer.ggml.model"              },
     { LLM_KV_TOKENIZER_LIST,                "tokenizer.ggml.tokens"             },
@@ -552,6 +561,22 @@ do { \
     } \
 } while (0)

+static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
+    { LLAMA_ROPE_SCALING_NONE,   "none"   },
+    { LLAMA_ROPE_SCALING_LINEAR, "linear" },
+    { LLAMA_ROPE_SCALING_YARN,   "yarn"   },
+};
+
+static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
+    for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+
+    return LLAMA_ROPE_SCALING_UNSPECIFIED;
+}
+
 //
 // ggml helpers
 //
@@ -1035,8 +1060,11 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;

-    float rope_freq_base_train;
-    float rope_freq_scale_train;
+    float    rope_freq_base_train;
+    float    rope_freq_scale_train;
+    uint32_t n_yarn_orig_ctx;
+    int8_t   rope_scaling_type_train : 3;
+    bool     rope_finetuned : 1;

     float f_clamp_kqv;
     float f_max_alibi_bias;
@@ -1051,6 +1079,8 @@ struct llama_hparams {
         if (this->n_layer     != other.n_layer)     return true;
         if (this->n_rot       != other.n_rot)       return true;
         if (this->n_ff        != other.n_ff)        return true;
+        if (this->rope_finetuned  != other.rope_finetuned)  return true;
+        if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;

         const float EPSILON = 1e-9;

@@ -1081,8 +1111,16 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing

-    float rope_freq_base;
-    float rope_freq_scale;
+    float    rope_freq_base;
+    float    rope_freq_scale;
+
+    uint32_t n_yarn_orig_ctx;
+    // These hyperparameters are not exposed in GGUF, because all
+    // existing YaRN models use the same values for them.
+    float yarn_ext_factor;
+    float yarn_attn_factor;
+    float yarn_beta_fast;
+    float yarn_beta_slow;

     bool mul_mat_q;
 };
@@ -2014,14 +2052,30 @@ static void llm_load_hparams(
     hparams.n_head_kv = hparams.n_head;
     GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));

+    hparams.rope_finetuned = false;
+    GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
+                 kv(LLM_KV_ROPE_SCALING_FINETUNED));
+
+    hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
+    GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
+                 kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
+
     // rope_freq_base (optional)
     hparams.rope_freq_base_train = 10000.0f;
     GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));

+    std::string rope_scaling("linear");
+    GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
+    hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
+    GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
+
     // rope_freq_scale (inverse of the kv) is optional
-    float ropescale = 1.0f;
-    GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
-    hparams.rope_freq_scale_train = 1.0f/ropescale;
+    float ropescale = 0.0f;
+    GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
+    if (ropescale == 0.0f) { // try the old key name
+        GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    }
+    hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;

     // sanity check for n_rot (optional)
     {
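A worked pass through the key-fallback chain above, with assumed metadata values:

    // hypothetical GGUF carrying "<arch>.rope.scaling.factor" = 4.0:
    //   ropescale = 4.0f                  (new key found; old key not consulted)
    //   rope_freq_scale_train = 1.0f/4.0f = 0.25f
    // neither the new nor the old key present:
    //   ropescale stays 0.0f              -> rope_freq_scale_train = 1.0f (no scaling)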
@@ -2371,6 +2425,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;

+    const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
@@ -2389,8 +2445,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str());
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+    LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",     __func__, hparams.n_yarn_orig_ctx);
+    LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
     LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
@@ -3047,21 +3106,11 @@ static void llm_load_tensors(
     model.t_load_us = ggml_time_us() - model.t_start_us;
 }

-static bool llama_model_load(
-        const std::string & fname,
-        llama_model & model,
-        int n_gpu_layers,
-        int main_gpu,
-        const float * tensor_split,
-        bool use_mmap,
-        bool use_mlock,
-        bool vocab_only,
-        llama_progress_callback progress_callback,
-        void *progress_callback_user_data) {
+static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
     try {
-        llama_model_loader ml(fname, use_mmap);
+        llama_model_loader ml(fname, params.use_mmap);

-        model.hparams.vocab_only = vocab_only;
+        model.hparams.vocab_only = params.vocab_only;

         llm_load_arch   (ml, model);
         llm_load_hparams(ml, model);
@@ -3073,15 +3122,15 @@ static bool llama_model_load(
             throw std::runtime_error("vocab size mismatch");
         }

-        if (vocab_only) {
+        if (params.vocab_only) {
             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
             return true;
         }

         llm_load_tensors(
-                ml, model, n_gpu_layers,
-                main_gpu, tensor_split,
-                use_mlock, progress_callback, progress_callback_user_data);
+            ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
+            params.progress_callback, params.progress_callback_user_data
+        );
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
         return false;
@@ -3150,6 +3199,7 @@ static struct ggml_tensor * llm_build_inp_embd(
 static void llm_build_k_shift(
       struct ggml_context * ctx,
       const llama_hparams & hparams,
+      const llama_cparams & cparams,
      const llama_kv_cache & kv,
        struct ggml_cgraph * graph,
             llm_rope_type   type,
@@ -3162,6 +3212,11 @@ static void llm_build_k_shift(
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
     const int64_t n_embd_head = hparams.n_embd_head();
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
+    const float   ext_factor  = cparams.yarn_ext_factor;
+    const float   attn_factor = cparams.yarn_attn_factor;
+    const float   beta_fast   = cparams.yarn_beta_fast;
+    const float   beta_slow   = cparams.yarn_beta_slow;

     GGML_ASSERT(n_embd_head % n_rot == 0);

@@ -3185,7 +3240,8 @@ static void llm_build_k_shift(
                         ggml_element_size(kv.k)*n_embd_head,
                         ggml_element_size(kv.k)*n_embd_gqa,
                         ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
-                    K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
+                    K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
         cb(tmp, "K_shifted", il);
         ggml_build_forward_expand(graph, tmp);
     }
@@ -3442,12 +3498,17 @@ struct llm_build_context {

     const float freq_base;
     const float freq_scale;
+    const float ext_factor;
+    const float attn_factor;
+    const float beta_fast;
+    const float beta_slow;
     const float norm_eps;
     const float norm_rms_eps;

     const int32_t n_tokens;
     const int32_t n_kv;     // size of KV cache to consider (n_kv <= n_ctx)
     const int32_t kv_head;  // index of where we store new KV data in the cache
+    const int32_t n_orig_ctx;

     const bool do_rope_shift;

@@ -3477,11 +3538,16 @@ struct llm_build_context {
         n_embd_gqa    (hparams.n_embd_gqa()),
         freq_base     (cparams.rope_freq_base),
         freq_scale    (cparams.rope_freq_scale),
+        ext_factor    (cparams.yarn_ext_factor),
+        attn_factor   (cparams.yarn_attn_factor),
+        beta_fast     (cparams.yarn_beta_fast),
+        beta_slow     (cparams.yarn_beta_slow),
         norm_eps      (hparams.f_norm_eps),
         norm_rms_eps  (hparams.f_norm_rms_eps),
         n_tokens      (batch.n_tokens),
         n_kv          (worst_case ? n_ctx            : kv_self.n),
         kv_head       (worst_case ? n_ctx - n_tokens : kv_self.head),
+        n_orig_ctx    (cparams.n_yarn_orig_ctx),
         do_rope_shift (worst_case || kv_self.has_shift),
         cb            (cb),
         buf_compute   (lctx.buf_compute) {
@@ -3532,7 +3598,7 @@ struct llm_build_context {

         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }

         for (int il = 0; il < n_layer; ++il) {
@@ -3556,10 +3622,18 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);

-                Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                    n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Qcur, "Qcur", il);

-                Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Kcur, "Kcur", il);

                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@@ -3634,7 +3708,7 @@ struct llm_build_context {

         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }

         for (int il = 0; il < n_layer; ++il) {
@@ -3658,8 +3732,16 @@ struct llm_build_context {

                 switch (model.type) {
                     case MODEL_7B:
-                        Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens),    inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-                        Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                        Qcur = ggml_rope_custom(
+                            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                            n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                        Kcur = ggml_rope_custom(
+                            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                            n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                        );
                         break;
                     case MODEL_13B:
                         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
@@ -3746,7 +3828,7 @@ struct llm_build_context {

         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }

         for (int il = 0; il < n_layer; ++il) {
@@ -3786,10 +3868,16 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);

                 // using mode = 2 for neox mode
-                Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+                Qcur = ggml_rope_custom(
+                    ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Qcur, "Qcur", il);

-                Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+                Kcur = ggml_rope_custom(
+                    ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Kcur, "Kcur", il);

                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@@ -3960,7 +4048,7 @@ struct llm_build_context {
         cb(KQ_mask, "KQ_mask", -1);

         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }

         for (int il = 0; il < n_layer; ++il) {
@@ -4053,13 +4141,15 @@ struct llm_build_context {
                 cb(kpass, "kpass", il);

                 struct ggml_tensor * qrotated = ggml_rope_custom(
-                        ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
-                        );
+                    ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(qrotated, "qrotated", il);

                 struct ggml_tensor * krotated = ggml_rope_custom(
-                        ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
-                        );
+                    ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(krotated, "krotated", il);

                 // ggml currently only supports concatenation on dim=2
@@ -7883,8 +7973,13 @@ struct llama_context_params llama_context_default_params() {
         /*.n_batch                     =*/ 512,
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
+        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
+        /*.yarn_ext_factor             =*/ NAN,
+        /*.yarn_attn_factor            =*/ 1.0f,
+        /*.yarn_beta_fast              =*/ 32.0f,
+        /*.yarn_beta_slow              =*/ 1.0f,
         /*.mul_mat_q                   =*/ true,
         /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
@@ -7971,10 +8066,7 @@ struct llama_model * llama_load_model_from_file(
         };
     }

-    if (!llama_model_load(path_model, *model, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split,
-                params.use_mmap, params.use_mlock, params.vocab_only,
-                params.progress_callback, params.progress_callback_user_data)) {
+    if (!llama_model_load(path_model, *model, params)) {
         LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
         delete model;
         return nullptr;
@@ -8000,13 +8092,35 @@ struct llama_context * llama_new_context_with_model(
     const auto & hparams = model->hparams;
     auto       & cparams = ctx->cparams;

-    cparams.n_batch         = params.n_batch;
-    cparams.n_ctx           = params.n_ctx == 0           ? hparams.n_ctx_train           : params.n_ctx;
-    cparams.rope_freq_base  = params.rope_freq_base == 0  ? hparams.rope_freq_base_train  : params.rope_freq_base;
-    cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
-    cparams.n_threads       = params.n_threads;
-    cparams.n_threads_batch = params.n_threads_batch;
-    cparams.mul_mat_q       = params.mul_mat_q;
+    cparams.n_batch          = params.n_batch;
+    cparams.n_threads        = params.n_threads;
+    cparams.n_threads_batch  = params.n_threads_batch;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.mul_mat_q        = params.mul_mat_q;
+
+    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
+    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
+    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+
+    cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
+                               hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
+                                                              hparams.n_ctx_train;
+
+    auto rope_scaling_type = params.rope_scaling_type;
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
+        rope_scaling_type = hparams.rope_scaling_type_train;
+    }
+
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
+        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
+    }
+
+    if (std::isnan(cparams.yarn_ext_factor)) { // NaN indicates 'not set'
+        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
+    }

     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
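A worked pass through the resolution logic above, under assumed inputs (none of these values come from a specific model):

    // params.yarn_orig_ctx == 0, hparams.n_yarn_orig_ctx == 8192 (assumed)
    //   -> cparams.n_yarn_orig_ctx = 8192; n_ctx_train is used only if both are 0
    // params.rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED and the model
    // was trained with "yarn"
    //   -> rope_scaling_type = LLAMA_ROPE_SCALING_YARN
    // params.yarn_ext_factor left at its NAN default
    //   -> cparams.yarn_ext_factor = 1.0f (it would be 0.0f for none/linear)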

+ 16 - 2
llama.h

@@ -106,6 +106,14 @@ extern "C" {
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };

+    enum llama_rope_scaling_type {
+        LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
+        LLAMA_ROPE_SCALING_NONE        = 0,
+        LLAMA_ROPE_SCALING_LINEAR      = 1,
+        LLAMA_ROPE_SCALING_YARN        = 2,
+        LLAMA_ROPE_SCALING_MAX_VALUE   = LLAMA_ROPE_SCALING_YARN,
+    };
+
     typedef struct llama_token_data {
         llama_token id; // token id
         float logit;    // log-odds of the token
@@ -172,10 +180,16 @@ extern "C" {
         uint32_t n_batch;         // prompt processing maximum batch size
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
+        int8_t   rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float rope_freq_base;  // RoPE base frequency, 0 = from model
-        float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
+        float    rope_freq_base;   // RoPE base frequency, 0 = from model
+        float    rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
+        float    yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
+        float    yarn_attn_factor; // YaRN magnitude scaling factor
+        float    yarn_beta_fast;   // YaRN low correction dim
+        float    yarn_beta_slow;   // YaRN high correction dim
+        uint32_t yarn_orig_ctx;    // YaRN original context size

         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
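Putting the new context-params fields together, a minimal configuration sketch against the public API; the 4x stretch and every concrete number are illustrative assumptions:

    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx             = 16384;                     // target context
    cparams.rope_scaling_type = LLAMA_ROPE_SCALING_YARN;
    cparams.rope_freq_scale   = 0.25f;                     // 4096 -> 16384
    cparams.yarn_orig_ctx     = 4096;                      // pre-training context
    // yarn_ext_factor stays NaN here and resolves to 1.0f for YaRN at context
    // creation; attn_factor/beta_fast/beta_slow keep their defaults
    struct llama_context * lctx = llama_new_context_with_model(model, cparams);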