
llama : sanitize invalid tokens (#9357)

* common : do not add null tokens during warmup

ggml-ci

* llama : check that the input tokens are valid

ggml-ci

* tests : fix batch size of bert model

ggml-ci
Georgi Gerganov, 1 year ago
Parent
Commit faf69d4237
3 changed files with 26 additions and 4 deletions:
  1. common/common.cpp (+7, -2)
  2. examples/server/tests/features/embeddings.feature (+5, -2)
  3. src/llama.cpp (+14, -0)

common/common.cpp (+7, -2)

@@ -2690,10 +2690,15 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         llama_token bos = llama_token_bos(model);
         llama_token eos = llama_token_eos(model);
         // some models (e.g. T5) don't have a BOS token
-        if (bos != -1) {
+        if (bos != LLAMA_TOKEN_NULL) {
             tmp.push_back(bos);
         }
-        tmp.push_back(eos);
+        if (eos != LLAMA_TOKEN_NULL) {
+            tmp.push_back(eos);
+        }
+        if (tmp.empty()) {
+            tmp.push_back(0);
+        }

         if (llama_model_has_encoder(model)) {
             llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
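For reference, the warmup token selection after this change can be read in isolation as the sketch below. It is a minimal illustration assuming the llama.h API used by this version (llama_token_bos, llama_token_eos, LLAMA_TOKEN_NULL); build_warmup_tokens is a hypothetical helper, not code from the patch.

    // Hypothetical helper illustrating the warmup token selection above
    // (assumes #include <vector> and "llama.h").
    // BOS/EOS can be LLAMA_TOKEN_NULL (-1) for models such as T5 and must not
    // be pushed into the warmup batch; token 0 is the fallback so the batch
    // is never empty.
    static std::vector<llama_token> build_warmup_tokens(const llama_model * model) {
        std::vector<llama_token> tmp;
        const llama_token bos = llama_token_bos(model);
        const llama_token eos = llama_token_eos(model);
        if (bos != LLAMA_TOKEN_NULL) {
            tmp.push_back(bos);
        }
        if (eos != LLAMA_TOKEN_NULL) {
            tmp.push_back(eos);
        }
        if (tmp.empty()) {
            tmp.push_back(0);
        }
        return tmp;
    }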

examples/server/tests/features/embeddings.feature (+5, -2)

@@ -9,8 +9,11 @@ Feature: llama.cpp server
     And   a model alias bert-bge-small
     And   42 as server seed
     And   2 slots
-    And   1024 as batch size
-    And   1024 as ubatch size
+    # the bert-bge-small model has context size of 512
+    # since the generated prompts are as big as the batch size, we need to set the batch size to 512
+    # ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
+    And   512 as batch size
+    And   512 as ubatch size
     And   2048 KV cache size
     And   embeddings extraction
     Then  the server is starting
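The comments in the hunk carry the reasoning: bge-small-en-v1.5 was trained with a 512-token context, and the test generates prompts as large as the batch, so a 1024-token batch would exceed what the encoder can embed in one pass. A program configuring such a model directly would follow the same constraint; the sketch below is illustrative only, assuming the standard llama_context_params fields and llama_n_ctx_train, and is not part of this commit.

    // Illustrative: cap batch/ubatch by the model's training context so an
    // encoder-style embedding model can process each prompt in a single pass
    // (assumes #include <algorithm> and "llama.h").
    const int32_t n_ctx_train = llama_n_ctx_train(model);        // 512 for bge-small-en-v1.5
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx      = 2048;                                 // KV cache size used by the test
    ctx_params.n_batch    = std::min<int32_t>(512, n_ctx_train);  // batch size
    ctx_params.n_ubatch   = ctx_params.n_batch;                   // ubatch size
    ctx_params.embeddings = true;                                 // embeddings extraction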

src/llama.cpp (+14, -0)

@@ -16066,6 +16066,13 @@ static int llama_decode_internal(
         return -1;
     }

+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        if (batch_all.token[i] < 0) {
+            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
+            return -1;
+        }
+    }
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -16358,6 +16365,13 @@ static int llama_encode_internal(
         return -1;
     }

+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        if (batch.token[i] < 0) {
+            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
+            return -1;
+        }
+    }
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
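Both loops reject negative token ids before any KV-cache or compute work is done, so llama_decode/llama_encode now return -1 instead of failing later with an out-of-range lookup. A caller can apply the same sanity check up front; the helper below is a minimal sketch, not code from this commit, and the extra upper-bound check against llama_n_vocab is an additional illustrative guard beyond what the patch adds.

    // Hypothetical caller-side check mirroring the validation added above
    // (assumes #include <cstdio>, #include <vector> and "llama.h").
    static bool tokens_are_valid(const llama_model * model, const std::vector<llama_token> & tokens) {
        const int32_t n_vocab = llama_n_vocab(model);
        for (size_t i = 0; i < tokens.size(); ++i) {
            // the commit rejects negative ids; checking the upper bound as
            // well guards against ids outside the model's vocabulary
            if (tokens[i] < 0 || tokens[i] >= n_vocab) {
                fprintf(stderr, "invalid token[%zu] = %d\n", i, tokens[i]);
                return false;
            }
        }
        return true;
    }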