
Stop the generation when the <|eom_id|> token is encountered - needed for Llama 3.1 tool call support (#8858)

* gguf-py, llama : add constants and methods related to Llama-3.1 <|eom_id|> token

* llama : find Llama-3.1 <|eom_id|> token id during vocab loading

* llama-vocab : add Llama-3.1 <|eom_id|> token to the set of tokens stopping the generation

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
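
Taken together, the three steps above mean a typical caller needs no changes: the public llama_token_is_eog() check (backed by llama_token_is_eog_impl() below) now also fires on <|eom_id|>. A minimal consumer-side sketch, assuming an already-loaded model and context and hypothetical sample_next_token()/emit() helpers:

    // Sketch of a generation loop; `model`, `ctx`, `n_predict` and the
    // sample_next_token()/emit() helpers are hypothetical stand-ins.
    for (int i = 0; i < n_predict; ++i) {
        const llama_token tok = sample_next_token(ctx);

        // After this commit the end-of-generation check also covers the
        // Llama-3.1 <|eom_id|> token, so tool calls stop cleanly.
        if (llama_token_is_eog(model, tok)) {
            break;
        }

        emit(tok);
    }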
fairydreaming 1 year ago
Parent
Commit d3f0c7166a
5 changed files with 27 additions and 1 deletion
  1. gguf-py/gguf/constants.py (+2 -0)
  2. gguf-py/gguf/gguf_writer.py (+3 -0)
  3. src/llama-vocab.cpp (+6 -1)
  4. src/llama-vocab.h (+2 -0)
  5. src/llama.cpp (+14 -0)

+ 2 - 0
gguf-py/gguf/constants.py

@@ -161,6 +161,7 @@ class Keys:
         SUFFIX_ID            = "tokenizer.ggml.suffix_token_id"
         MIDDLE_ID            = "tokenizer.ggml.middle_token_id"
         EOT_ID               = "tokenizer.ggml.eot_token_id"
+        EOM_ID               = "tokenizer.ggml.eom_token_id"
 
     class Adapter:
         TYPE       = "adapter.type"
@@ -1327,3 +1328,4 @@ KEY_TOKENIZER_PRIFIX_ID  = Keys.Tokenizer.PREFIX_ID
 KEY_TOKENIZER_SUFFIX_ID  = Keys.Tokenizer.SUFFIX_ID
 KEY_TOKENIZER_MIDDLE_ID  = Keys.Tokenizer.MIDDLE_ID
 KEY_TOKENIZER_EOT_ID     = Keys.Tokenizer.EOT_ID
+KEY_TOKENIZER_EOM_ID     = Keys.Tokenizer.EOM_ID

+ 3 - 0
gguf-py/gguf/gguf_writer.py

@@ -828,6 +828,9 @@ class GGUFWriter:
     def add_eot_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.EOT_ID, id)
 
+    def add_eom_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.EOM_ID, id)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
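
The new writer method pairs with the loader-side TODO in src/llama.cpp further down: once the convert scripts call add_eom_token_id(), the text-based fallback lookup becomes unnecessary. As an equivalent illustration in C++ (using ggml's C gguf API rather than the Python path the convert scripts actually take), the same KV pair can be written like this; 128008 is assumed to be the Llama-3.1 <|eom_id|> id and is hard-coded purely for the sketch:

    // Sketch: writing tokenizer.ggml.eom_token_id with ggml's C gguf API
    // (convert scripts would use GGUFWriter.add_eom_token_id() above instead).
    #include "ggml.h" // the gguf_* functions live here at this point in time

    int main(void) {
        struct gguf_context * gctx = gguf_init_empty();

        // Same key that Keys.Tokenizer.EOM_ID maps to; the id is assumed,
        // not read from a real tokenizer.
        gguf_set_val_u32(gctx, "tokenizer.ggml.eom_token_id", 128008);

        gguf_write_to_file(gctx, "eom-meta.gguf", /*only_meta=*/true);
        gguf_free(gctx);
        return 0;
    }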

+ 6 - 1
src/llama-vocab.cpp

@@ -1444,7 +1444,8 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
     return token != -1 && (
         token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab)
+        token == llama_token_eot_impl(vocab) ||
+        token == llama_token_eom_impl(vocab)
     );
 }
 
@@ -1500,6 +1501,10 @@ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
     return vocab.special_eot_id;
 }
 
+llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eom_id;
+}
+
 int32_t llama_tokenize_impl(
     const struct llama_vocab & vocab,
                   const char * text,

+ 2 - 0
src/llama-vocab.h

@@ -45,6 +45,7 @@ struct llama_vocab {
     id special_suffix_id = -1;
     id special_middle_id = -1;
     id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
+    id special_eom_id    = -1;
 
     // tokenizer flags
     bool tokenizer_add_space_prefix = false;
@@ -101,6 +102,7 @@ llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
 llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
 llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
 llama_token llama_token_eot_impl   (const struct llama_vocab & vocab);
+llama_token llama_token_eom_impl   (const struct llama_vocab & vocab);
 
 int32_t llama_tokenize_impl(
         const struct llama_vocab & vocab,

+ 14 - 0
src/llama.cpp

@@ -359,6 +359,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
     LLM_KV_TOKENIZER_EOT_ID,
+    LLM_KV_TOKENIZER_EOM_ID,
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
@@ -456,6 +457,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_SUFFIX_ID,            "tokenizer.ggml.suffix_token_id"          },
     { LLM_KV_TOKENIZER_MIDDLE_ID,            "tokenizer.ggml.middle_token_id"          },
     { LLM_KV_TOKENIZER_EOT_ID,               "tokenizer.ggml.eot_token_id"             },
+    { LLM_KV_TOKENIZER_EOM_ID,               "tokenizer.ggml.eom_token_id"             },
 
     { LLM_KV_ADAPTER_TYPE,                  "adapter.type"       },
     { LLM_KV_ADAPTER_LORA_ALPHA,            "adapter.lora.alpha" },
@@ -5583,6 +5585,7 @@ static void llm_load_vocab(
             { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
             { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
             { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
+            { LLM_KV_TOKENIZER_EOM_ID,    vocab.special_eom_id    },
         };
 
         for (const auto & it : special_token_types) {
@@ -5635,6 +5638,17 @@ static void llm_load_vocab(
                 }
             }
         }
+
+        // find EOM token: "<|eom_id|>"
+        //
+        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
+        //       for now, we apply this workaround to find the EOM token based on its text
+        if (vocab.special_eom_id == -1) {
+            const auto & t = vocab.token_to_id.find("<|eom_id|>");
+            if (t != vocab.token_to_id.end()) {
+                vocab.special_eom_id = t->second;
+            }
+        }
     }
 
     // build special tokens cache
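
A quick way to see which path the loader takes is to inspect a file's metadata directly: if the key is present, special_eom_id is filled from the special_token_types table above, and only otherwise does the "<|eom_id|>" text lookup run. A standalone sketch using ggml's gguf API (file name handling and output are illustrative):

    // Sketch: check whether a GGUF file already carries the EOM token id.
    #include "ggml.h"
    #include <stdio.h>

    int main(int argc, char ** argv) {
        struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
        struct gguf_context * gctx = gguf_init_from_file(argv[1], params);
        if (!gctx) {
            return 1;
        }

        const int key_id = gguf_find_key(gctx, "tokenizer.ggml.eom_token_id");
        if (key_id >= 0) {
            printf("eom_token_id = %u\n", gguf_get_val_u32(gctx, key_id));
        } else {
            printf("no metadata id; loader falls back to the \"<|eom_id|>\" text lookup\n");
        }

        gguf_free(gctx);
        return 0;
    }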