Переглянути джерело

llama : allow getting n_batch from llama_context in c api (#4540)

* allowed getting n_batch from llama_context in c api

* changed to use `uint32_t` instead of `int`

* changed to use `uint32_t` instead of `int` in `llama_n_ctx`

* Update llama.h

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Marcus Dunn 2 роки тому
батько
коміт
31f27758fa
2 змінених файлів з 8 додано та 2 видалено
  1. 5 1
      llama.cpp
  2. 3 1
      llama.h

+ 5 - 1
llama.cpp

@@ -9532,10 +9532,14 @@ const llama_model * llama_get_model(const struct llama_context * ctx) {
     return &ctx->model;
 }
 
-int llama_n_ctx(const struct llama_context * ctx) {
+uint32_t llama_n_ctx(const struct llama_context * ctx) {
     return ctx->cparams.n_ctx;
 }
 
+uint32_t llama_n_batch(const struct llama_context * ctx) {
+    return ctx->cparams.n_batch;
+}
+
 enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }

+ 3 - 1
llama.h

@@ -314,7 +314,9 @@ extern "C" {
 
     LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
 
-    LLAMA_API int llama_n_ctx      (const struct llama_context * ctx);
+    // TODO: become more consistent with returned int types across the API
+    LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);