Просмотр исходного кода

llama : fix memory leak in llama_batch_free (#5252)

The llama_batch_init allocates memory for a fixed number of tokens.
However, the llama_batch_free only frees memory for the number of
tokens that were added to the batch.

This change-set uses a null terminated array for the batch seq_id, and
frees all the elements until the nullptr is reached. This change-set
also changes the name of the first parameter from `n_tokens` to
`n_tokens_alloc` to more clearly indicate that this value is the number
of tokens allocated to the batch, not the number of tokens in the batch.
Ian Bull 1 год назад
Родитель
Сommit
e1e721094d
1 измененных файлов с 11 добавлено и 9 удалено
  1. 11 9
      llama.cpp

+ 11 - 9
llama.cpp

@@ -11377,22 +11377,24 @@ struct llama_batch llama_batch_get_one(
     };
     };
 }
 }
 
 
-struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
     llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
 
 
     if (embd) {
     if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
     } else {
     } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
     }
 
 
-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
+    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
+    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
+    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
+    for (int i = 0; i < n_tokens_alloc; ++i) {
         batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
         batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
     }
-    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
+    batch.seq_id[n_tokens_alloc] = nullptr;
+
+    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
 
 
     return batch;
     return batch;
 }
 }
@@ -11403,7 +11405,7 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.pos)      free(batch.pos);
     if (batch.pos)      free(batch.pos);
     if (batch.n_seq_id) free(batch.n_seq_id);
     if (batch.n_seq_id) free(batch.n_seq_id);
     if (batch.seq_id) {
     if (batch.seq_id) {
-        for (int i = 0; i < batch.n_tokens; ++i) {
+        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
             free(batch.seq_id[i]);
             free(batch.seq_id[i]);
         }
         }
         free(batch.seq_id);
         free(batch.seq_id);