Explorar o código

graph: add f_attn_temp_offset (#18025)

Xuan-Son Nguyen hai 1 mes
pai
achega
0759b09c90
Modificáronse 4 ficheiros con 11 adicións e 4 borrados
  1. 2 2
      src/llama-graph.cpp
  2. 3 2
      src/llama-graph.h
  3. 1 0
      src/llama-hparams.h
  4. 5 0
      src/llama-model.cpp

+ 2 - 2
src/llama-graph.cpp

@@ -78,7 +78,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
         for (int i = 0; i < n_tokens; ++i) {
             const float pos = ubatch->pos[i];
             attn_scale_data[i] = std::log(
-                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
             ) * f_attn_temp_scale + 1.0;
         }
 
@@ -1203,7 +1203,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
 
     auto & cur = inp->attn_scale;
 

+ 3 - 2
src/llama-graph.h

@@ -132,8 +132,8 @@ public:
 // temperature tuning, used by llama4
 class llm_graph_input_attn_temp : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
     virtual ~llm_graph_input_attn_temp() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -142,6 +142,7 @@ public:
 
     const uint32_t n_attn_temp_floor_scale;
     const float    f_attn_temp_scale;
+    const float    f_attn_temp_offset;
 };
 
 class llm_graph_input_pos_bucket : public llm_graph_input_i {

+ 1 - 0
src/llama-hparams.h

@@ -165,6 +165,7 @@ struct llama_hparams {
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 0;
     float    f_attn_temp_scale       = 0.0f;
+    float    f_attn_temp_offset      = 0.0f; // offset position index
 
     // gemma3n altup
     uint32_t n_altup      = 4; // altup_num_inputs

+ 5 - 0
src/llama-model.cpp

@@ -668,6 +668,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.n_swa                   = 8192;
                     hparams.n_attn_temp_floor_scale = 8192;
                     hparams.f_attn_temp_scale       = 0.1f;
+                    hparams.f_attn_temp_offset      = 1.0f;
                     hparams.set_swa_pattern(4);   // pattern: 3 chunked - 1 full
                 }
 
@@ -1646,6 +1647,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE,  hparams.f_attn_temp_scale,       false);
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);
 
+                hparams.f_attn_temp_offset = 0.0f;
+
                 switch (hparams.n_layer) {
                     case 27: type = LLM_TYPE_16B; break;
                     case 60: type = LLM_TYPE_236B; break;
@@ -2276,6 +2279,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow,    false);
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,   hparams.rope_yarn_log_mul, 0.0f);
 
+                hparams.f_attn_temp_offset = 0.0f;
+
                 // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
                 if (hparams.f_attn_temp_scale != 0.0f) {
                     hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;