@@ -578,6 +578,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);

+ // TODO: Handle SWA metadata similarly when models start implementing it
// rope_freq_scale (inverse of the kv) is optional
float ropescale = 0.0f;
if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
@@ -586,10 +587,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;

- // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
- hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
- hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
-
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);

// non-transformer models do not have attention heads
@@ -677,6 +674,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.f_attn_temp_scale = 0.1f;
hparams.f_attn_temp_offset = 1.0f;
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
+
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
}

switch (hparams.n_expert) {
@@ -722,6 +723,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
if (hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(4);
+
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
} else {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
}
@@ -1243,7 +1248,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
if (found_swa && hparams.n_swa > 0) {
uint32_t swa_period = 8;
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
- hparams.rope_freq_scale_train_swa = 1.0f;
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
hparams.set_swa_pattern(swa_period);
@@ -1309,7 +1313,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.n_swa = 4096; // default value of gemma 2
hparams.set_swa_pattern(2);
hparams.attn_soft_cap = true;
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;

+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
@@ -1334,8 +1341,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(6);

- hparams.rope_freq_base_train_swa = 10000.0f;
- hparams.rope_freq_scale_train_swa = 1.0f;
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
} else {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
}
@@ -1365,10 +1371,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.set_swa_pattern(5);

hparams.n_layer_kv_from_start = 20;
- hparams.rope_freq_base_train_swa = 10000.0f;
- hparams.rope_freq_scale_train_swa = 1.0f;
hparams.f_attention_scale = 1.0f;

+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

@@ -1384,9 +1389,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.set_swa_pattern(6);

hparams.causal_attn = false; // embeddings do not use causal attention
- hparams.rope_freq_base_train_swa = 10000.0f;
- hparams.rope_freq_scale_train_swa = 1.0f;

+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
@@ -1525,7 +1529,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(4);
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;

+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1564,6 +1571,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
if (found_swa && hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(4);
+
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = 1.0f; // See olmo2.cpp
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
} else {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
}
@@ -1906,6 +1917,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096;
hparams.set_swa_pattern(4);
+
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
}

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -2208,6 +2223,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(2);

+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
+
switch (hparams.n_layer) {
case 24: type = LLM_TYPE_20B; break;
case 36: type = LLM_TYPE_120B; break;
@@ -2252,6 +2271,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096;
hparams.set_swa_pattern(4, true);
+
+ hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
+ hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
} else {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_no_rope_layer_step = hparams.n_layer;
@@ -7098,6 +7121,10 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: rope scaling     = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ LLAMA_LOG_INFO("%s: freq_base_swa    = %.1f\n", __func__, hparams.rope_freq_base_train_swa);
+ LLAMA_LOG_INFO("%s: freq_scale_swa   = %g\n", __func__, hparams.rope_freq_scale_train_swa);
+ }
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n", __func__, hparams.n_ctx_orig_yarn);
LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");