Просмотр исходного кода

llama : expose llama_model_n_head_kv in the API (#11997)

It's useful to be able to have this from the library layer as it's a key
parameter of the model (e.g. to figure out how much KV cache memory is
needed).
Vitali Lovich 10 месяцев назад
Родитель
Сommit
3e9a2860e9
2 измененных файлов с 5 добавлено и 0 удалено
  1. 1 0
      include/llama.h
  2. 4 0
      src/llama-model.cpp

+ 1 - 0
include/llama.h

@@ -477,6 +477,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

+ 4 - 0
src/llama-model.cpp

@@ -3838,6 +3838,10 @@ int32_t llama_model_n_head(const struct llama_model * model) {
     return model->hparams.n_head();
 }
 
+int32_t llama_model_n_head_kv(const struct llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const struct llama_model * model) {
     return llama_model_n_ctx_train(model);