Browse source code

memory : handle kv_unified for hybrid models (#15050)

compilade 5 months ago
parent
commit
11a3811164
3 changed files with 4 additions and 1 deletion
  1. src/llama-memory-hybrid.cpp  + 2 - 1
  2. src/llama-memory-hybrid.h    + 1 - 0
  3. src/llama-model.cpp          + 1 - 0

+ 2 - 1
src/llama-memory-hybrid.cpp

@@ -25,6 +25,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                          /* common */
              uint32_t    n_seq_max,
                  bool    offload,
+                 bool    unified,
                          /* layer filters */
       layer_filter_cb && filter_attn,
       layer_filter_cb && filter_recr) :
@@ -38,7 +39,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         type_v,
         v_trans,
         offload,
-        1,
+        unified,
         kv_size,
         n_seq_max,
         n_pad,

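The hunk above replaces the hard-coded 1 with the forwarded unified flag: until this change, the hybrid memory always constructed its internal attention cache as a unified one, regardless of what the caller requested. A minimal sketch of the pattern, using hypothetical type and member names rather than the actual llama.cpp internals:

    #include <cstdio>

    // Hypothetical stand-in for the internal attention KV cache.
    struct attn_cache_sketch {
        explicit attn_cache_sketch(bool unified) : unified(unified) {}
        bool unified;
    };

    // Hypothetical stand-in for the hybrid memory: the constructor
    // now forwards the caller's flag instead of passing a literal 1.
    struct hybrid_memory_sketch {
        explicit hybrid_memory_sketch(bool unified) : attn(unified) {}
        attn_cache_sketch attn;
    };

    int main() {
        hybrid_memory_sketch mem(/* unified */ false); // previously always true
        std::printf("unified: %d\n", mem.attn.unified);
    }
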
+ 1 - 0
src/llama-memory-hybrid.h

@@ -39,6 +39,7 @@ public:
                              /* common */
                  uint32_t    n_seq_max,
                      bool    offload,
+                     bool    unified,
                              /* layer filters */
           layer_filter_cb && filter_attn = nullptr,
           layer_filter_cb && filter_recr = nullptr);

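Note where the new parameter lands in the declaration: the layer-filter callbacks carry nullptr defaults, and C++ requires every defaulted parameter to come after all non-defaulted ones, so unified must be inserted before them. A compressed sketch of the resulting declaration shape (unrelated parameters elided, class name hypothetical):

    #include <cstdint>
    #include <functional>

    using layer_filter_cb = std::function<bool(int32_t)>;

    class hybrid_memory_decl_sketch {
    public:
        hybrid_memory_decl_sketch(
                uint32_t    n_seq_max,
                    bool    offload,
                    bool    unified,   // new flag, placed before defaulted params
            layer_filter_cb && filter_attn = nullptr,
            layer_filter_cb && filter_recr = nullptr);
    };
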
+ 1 - 0
src/llama-model.cpp

@@ -17598,6 +17598,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
                         /* n_seq_max         */ cparams.n_seq_max,
                         /* offload           */ cparams.offload_kqv,
+                        /* unified           */ cparams.kv_unified,
                         /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
                         /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
                 } else {
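
At the call site, create_memory reads the flag from cparams.kv_unified, so hybrid models now respect the same setting as plain attention models instead of silently behaving as unified. A hedged usage sketch, assuming a llama.h from a build of this period that already exposes the kv_unified field and llama_init_from_model:

    #include "llama.h"

    // Sketch: request a unified KV buffer (one buffer shared across all
    // sequences) when creating a context for a hybrid model.
    llama_context * make_ctx(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.kv_unified = true; // assumption: field present in this build
        return llama_init_from_model(model, cparams);
    }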