ソースを参照

Respect tokenizer.ggml.add_bos_token value when tokenizing (#4040)

* gguf-py: gguf-dump: Respect --no-tensor flag in JSON mode.

* Respect add_bos_token GGUF metadata value

* gguf-py: Try to fix SpecialVocab giving up too easily for the Nth time
Kerfuffle 2 年 前
コミット
91f6499393

+ 6 - 0
common/common.cpp

@@ -1072,6 +1072,12 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
     return result;
 }
 
+bool llama_should_add_bos_token(const llama_model * model) {
+    const int add_bos = llama_add_bos_token(model);
+
+    return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+}
+
 //
 // YAML utils
 //

+ 4 - 0
common/common.h

@@ -200,6 +200,10 @@ std::string llama_detokenize_bpe(
                          llama_context * ctx,
         const std::vector<llama_token> & tokens);
 
+// Uses the value from the model metadata if possible, otherwise
+// defaults to true when model type is SPM, otherwise false.
+bool llama_should_add_bos_token(const llama_model * model);
+
 //
 // YAML utils
 //

+ 1 - 1
examples/infill/infill.cpp

@@ -230,7 +230,7 @@ int main(int argc, char ** argv) {
         LOG_TEE("\n");
         LOG_TEE("%s\n", get_system_info(params).c_str());
     }
-    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_should_add_bos_token(model);
     LOG("add_bos: %d\n", add_bos);
 
     bool suff_rm_leading_spc = params.escape;

+ 2 - 1
examples/llava/llava-cli.cpp

@@ -208,9 +208,10 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     int n_past = 0;
 
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
 
     // llava chat format is "<system_prompt>\nUSER:<image_embeddings>\n<textual_prompt>\nASSISTANT:"
-    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant.  The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, true);
+    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant.  The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, add_bos);
     llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, (prompt + "\nASSISTANT:").c_str(), params->n_batch, &n_past, false);
 

+ 1 - 1
examples/main/main.cpp

@@ -229,7 +229,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_should_add_bos_token(model);
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;

+ 3 - 5
examples/perplexity/perplexity.cpp

@@ -149,8 +149,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
 
-    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
@@ -288,8 +287,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
 
-    const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
 
     auto tim1 = std::chrono::high_resolution_clock::now();
@@ -481,7 +479,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     fprintf(stderr, "================================= is_spm = %d\n", is_spm);
 
     // This is needed as usual for LLaMA models
-    const bool add_bos = is_spm;
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     // Number of tasks to use when computing the score
     if ( params.hellaswag_tasks < hs_task_count  ) {

+ 6 - 3
examples/server/server.cpp

@@ -501,6 +501,7 @@ struct llama_server_context
     bool multimodal         = false;
     bool clean_kv_cache     = true;
     bool all_slots_are_idle = false;
+    bool add_bos_token      = true;
 
     int32_t id_gen;
     int32_t n_ctx;  // total context for all clients / slots
@@ -573,6 +574,8 @@ struct llama_server_context
 
         n_ctx = llama_n_ctx(ctx);
 
+        add_bos_token = llama_should_add_bos_token(model);
+
         return true;
     }
 
@@ -864,7 +867,7 @@ struct llama_server_context
     }
 
     void update_system_prompt() {
-        system_tokens = ::llama_tokenize(ctx, system_prompt, true);
+        system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
 
         llama_batch_clear(batch);
 
@@ -1552,7 +1555,7 @@ struct llama_server_context
                     }
                     else
                     {
-                        prompt_tokens = tokenize(slot.prompt, system_prompt.empty());  // add BOS if there isn't system prompt
+                        prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt
                     }
 
                     slot.num_prompt_tokens = prompt_tokens.size();
@@ -1629,7 +1632,7 @@ struct llama_server_context
                     const bool has_images = process_images(slot);
 
                     // process the prefix of first image
-                    std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
+                    std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
                     for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                     {
                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);

+ 15 - 10
gguf-py/gguf/vocab.py

@@ -117,17 +117,18 @@ class SpecialVocab:
 
     def _try_load_from_tokenizer_json(self, path: Path) -> bool:
         tokenizer_file = path / 'tokenizer.json'
-        if not tokenizer_file.is_file():
-            return False
-        with open(tokenizer_file, encoding = 'utf-8') as f:
-            tokenizer = json.load(f)
-        if self.load_merges:
-            merges = tokenizer.get('model', {}).get('merges')
-            if isinstance(merges, list) and merges and isinstance(merges[0], str):
-                self.merges = merges
+        if tokenizer_file.is_file():
+            with open(tokenizer_file, encoding = 'utf-8') as f:
+                tokenizer = json.load(f)
+            if self.load_merges:
+                merges = tokenizer.get('model', {}).get('merges')
+                if isinstance(merges, list) and merges and isinstance(merges[0], str):
+                    self.merges = merges
+            added_tokens = tokenizer.get('added_tokens', {})
+        else:
+            added_tokens = {}
         tokenizer_config_file = path / 'tokenizer_config.json'
-        added_tokens = tokenizer.get('added_tokens')
-        if added_tokens is None or not tokenizer_config_file.is_file():
+        if not tokenizer_config_file.is_file():
             return True
         with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
@@ -135,6 +136,10 @@ class SpecialVocab:
             add_entry = tokenizer_config.get(f'add_{typ}_token')
             if isinstance(add_entry, bool):
                 self.add_special_token[typ] = add_entry
+            if not added_tokens:
+                # We will need this to get the content for the token, so if it's empty
+                # may as well just give up.
+                continue
             entry = tokenizer_config.get(f'{typ}_token')
             if isinstance(entry, str):
                 tc_content = entry

+ 1 - 1
gguf-py/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.5.2"
+version = "0.5.3"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [

+ 8 - 7
gguf-py/scripts/gguf-dump.py

@@ -86,13 +86,14 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
             curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8")
         else:
             curr["value"] = field.parts[-1].tolist()[0]
-    for idx, tensor in enumerate(reader.tensors):
-        tensors[tensor.name] = {
-            "index": idx,
-            "shape": tensor.shape.tolist(),
-            "type": tensor.tensor_type.name,
-            "offset": tensor.field.offset,
-        }
+    if not args.no_tensors:
+        for idx, tensor in enumerate(reader.tensors):
+            tensors[tensor.name] = {
+                "index": idx,
+                "shape": tensor.shape.tolist(),
+                "type": tensor.tensor_type.name,
+                "offset": tensor.field.offset,
+            }
     json.dump(result, sys.stdout)
 
 

+ 32 - 0
llama.cpp

@@ -255,6 +255,8 @@ enum llm_kv {
     LLM_KV_TOKENIZER_UNK_ID,
     LLM_KV_TOKENIZER_SEP_ID,
     LLM_KV_TOKENIZER_PAD_ID,
+    LLM_KV_TOKENIZER_ADD_BOS,
+    LLM_KV_TOKENIZER_ADD_EOS,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
 };
@@ -303,6 +305,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_UNK_ID,              "tokenizer.ggml.unknown_token_id"   },
     { LLM_KV_TOKENIZER_SEP_ID,              "tokenizer.ggml.seperator_token_id" },
     { LLM_KV_TOKENIZER_PAD_ID,              "tokenizer.ggml.padding_token_id"   },
+    { LLM_KV_TOKENIZER_ADD_BOS,             "tokenizer.ggml.add_bos_token"      },
+    { LLM_KV_TOKENIZER_ADD_EOS,             "tokenizer.ggml.add_eos_token"      },
     { LLM_KV_TOKENIZER_HF_JSON,             "tokenizer.huggingface.json"        },
     { LLM_KV_TOKENIZER_RWKV,                "tokenizer.rwkv.world"              },
 };
@@ -1276,6 +1280,9 @@ struct llama_vocab {
     id special_sep_id = -1;
     id special_pad_id = -1;
 
+    int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
+    int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
+
     id linefeed_id       = 13;
     id special_prefix_id = 32007;
     id special_middle_id = 32009;
@@ -2388,6 +2395,23 @@ static void llm_load_vocab(
                     __func__, key.c_str(), id, old_id);
                 id = old_id;
             }
+
+        }
+
+        // Handle add_bos_token and add_eos_token
+        std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
+        int kid = gguf_find_key(ctx, key.c_str());
+        enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
+        vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
+        if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
+            LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
+        }
+        key = kv(LLM_KV_TOKENIZER_ADD_EOS);
+        kid = gguf_find_key(ctx, key.c_str());
+        ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
+        vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
+        if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
+            LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
         }
     }
 
@@ -9288,6 +9312,14 @@ llama_token llama_token_nl(const struct llama_model * model) {
     return model->vocab.linefeed_id;
 }
 
+int llama_add_bos_token(const struct llama_model * model) {
+    return model->vocab.special_add_bos;
+}
+
+int llama_add_eos_token(const struct llama_model * model) {
+    return model->vocab.special_add_eos;
+}
+
 llama_token llama_token_prefix(const struct llama_model * model) {
     return model->vocab.special_prefix_id;
 }

+ 6 - 0
llama.h

@@ -517,6 +517,12 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
 
+    // Returns -1 if unknown, 1 for true or 0 for false.
+    LLAMA_API int         llama_add_bos_token(const struct llama_model * model);
+
+    // Returns -1 if unknown, 1 for true or 0 for false.
+    LLAMA_API int         llama_add_eos_token(const struct llama_model * model);
+
     // codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
     LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle