gguf : add special tokens metadata for FIM/Infill (#6689)

This commit adds special token metadata for Fill-In-the-Middle
(FIM)/Infill to converted GGUF models.

The motivation is that FIM is currently only wired up for CodeLlama,
but other models such as CodeGemma now exist, and different models use
different token ids for the FIM special tokens. Storing the ids as
GGUF metadata allows multiple models to be supported.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Daniel Bevenius 1 year ago
parent
commit
4fbd8098e6
4 changed files with 83 additions and 11 deletions
  1. convert-hf-to-gguf.py       (+15 -0)
  2. gguf-py/gguf/constants.py   (+9 -0)
  3. gguf-py/gguf/gguf_writer.py (+12 -0)
  4. llama.cpp                   (+47 -11)
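
After this change, a converted CodeLlama model, for example, carries the
following key/value pairs (taken from the ids hardcoded in the diff below):

    tokenizer.ggml.prefix_token_id = 32007
    tokenizer.ggml.suffix_token_id = 32008
    tokenizer.ggml.middle_token_id = 32009
    tokenizer.ggml.eot_token_id    = 32010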

+ 15 - 0
convert-hf-to-gguf.py

@@ -1221,6 +1221,14 @@ class LlamaModel(Model):
         except FileNotFoundError:
             self._set_vocab_llama_hf()

+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
+        special_vocab._set_special_token("prefix", 32007)
+        special_vocab._set_special_token("suffix", 32008)
+        special_vocab._set_special_token("middle", 32009)
+        special_vocab._set_special_token("eot",    32010)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -2240,6 +2248,13 @@ class GemmaModel(Model):

     def set_vocab(self):
         self._set_vocab_sentencepiece()
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
+        special_vocab._set_special_token("prefix", 67)
+        special_vocab._set_special_token("suffix", 69)
+        special_vocab._set_special_token("middle", 68)
+        special_vocab._set_special_token("eot",    70)
+        special_vocab.add_to_gguf(self.gguf_writer)

     def set_gguf_parameters(self):
         hparams = self.hparams
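
For illustration, a minimal sketch of verifying the newly written metadata
after conversion, assuming gguf-py's GGUFReader field layout; the file name
codellama.gguf is a placeholder:

    from gguf import GGUFReader, Keys

    reader = GGUFReader("codellama.gguf")  # placeholder path
    for key in (Keys.Tokenizer.PREFIX_ID, Keys.Tokenizer.SUFFIX_ID,
                Keys.Tokenizer.MIDDLE_ID, Keys.Tokenizer.EOT_ID):
        field = reader.get_field(key)
        if field is not None:
            # each of these keys holds a single uint32 token id
            print(key, field.parts[field.data[0]][0])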

+ 9 - 0
gguf-py/gguf/constants.py

@@ -90,6 +90,11 @@ class Keys:
         HF_JSON          = "tokenizer.huggingface.json"
         RWKV             = "tokenizer.rwkv.world"
         CHAT_TEMPLATE    = "tokenizer.chat_template"
+        # FIM/Infill special tokens constants
+        PREFIX_ID        = "tokenizer.ggml.prefix_token_id"
+        SUFFIX_ID        = "tokenizer.ggml.suffix_token_id"
+        MIDDLE_ID        = "tokenizer.ggml.middle_token_id"
+        EOT_ID           = "tokenizer.ggml.eot_token_id"


 #
@@ -885,3 +890,7 @@ KEY_TOKENIZER_CLS_ID     = Keys.Tokenizer.CLS_ID
 KEY_TOKENIZER_MASK_ID    = Keys.Tokenizer.MASK_ID
 KEY_TOKENIZER_HF_JSON    = Keys.Tokenizer.HF_JSON
 KEY_TOKENIZER_RWKV       = Keys.Tokenizer.RWKV
+KEY_TOKENIZER_PREFIX_ID  = Keys.Tokenizer.PREFIX_ID
+KEY_TOKENIZER_SUFFIX_ID  = Keys.Tokenizer.SUFFIX_ID
+KEY_TOKENIZER_MIDDLE_ID  = Keys.Tokenizer.MIDDLE_ID
+KEY_TOKENIZER_EOT_ID     = Keys.Tokenizer.EOT_ID
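
For reference, the new attributes resolve to the key strings below; a quick
check, assuming gguf-py is importable:

    from gguf.constants import Keys

    # the four FIM/Infill keys introduced by this commit
    print(Keys.Tokenizer.PREFIX_ID)  # tokenizer.ggml.prefix_token_id
    print(Keys.Tokenizer.SUFFIX_ID)  # tokenizer.ggml.suffix_token_id
    print(Keys.Tokenizer.MIDDLE_ID)  # tokenizer.ggml.middle_token_id
    print(Keys.Tokenizer.EOT_ID)     # tokenizer.ggml.eot_token_id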

+ 12 - 0
gguf-py/gguf/gguf_writer.py

@@ -469,6 +469,18 @@ class GGUFWriter:
     def add_chat_template(self, value: str) -> None:
         self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

+    def add_prefix_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)
+
+    def add_suffix_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)
+
+    def add_middle_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)
+
+    def add_eot_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.EOT_ID, id)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
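
Not part of the commit: a sketch of calling the new setters directly when
building a GGUF file by hand; the output path and architecture are
placeholders, and the header/KV writing follows GGUFWriter's usual flow:

    from gguf import GGUFWriter

    writer = GGUFWriter("out.gguf", arch="llama")  # placeholder path/arch
    # write CodeLlama's FIM token ids as uint32 metadata
    writer.add_prefix_token_id(32007)
    writer.add_suffix_token_id(32008)
    writer.add_middle_token_id(32009)
    writer.add_eot_token_id(32010)
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.close()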

+ 47 - 11
llama.cpp

@@ -327,6 +327,10 @@ enum llm_kv {
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
+    LLM_KV_TOKENIZER_PREFIX_ID,
+    LLM_KV_TOKENIZER_SUFFIX_ID,
+    LLM_KV_TOKENIZER_MIDDLE_ID,
+    LLM_KV_TOKENIZER_EOT_ID,
 };

 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@@ -399,6 +403,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_ADD_PREFIX,          "tokenizer.ggml.add_space_prefix"   },
     { LLM_KV_TOKENIZER_HF_JSON,             "tokenizer.huggingface.json"        },
     { LLM_KV_TOKENIZER_RWKV,                "tokenizer.rwkv.world"              },
+    { LLM_KV_TOKENIZER_PREFIX_ID,           "tokenizer.ggml.prefix_token_id"    },
+    { LLM_KV_TOKENIZER_SUFFIX_ID,           "tokenizer.ggml.suffix_token_id"    },
+    { LLM_KV_TOKENIZER_MIDDLE_ID,           "tokenizer.ggml.middle_token_id"    },
+    { LLM_KV_TOKENIZER_EOT_ID,              "tokenizer.ggml.eot_token_id"       },
 };

 struct LLM_KV {
@@ -2055,10 +2063,10 @@ struct llama_vocab {
     int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.

     id linefeed_id       = 13;
-    id special_prefix_id = 32007;
-    id special_middle_id = 32009;
-    id special_suffix_id = 32008;
-    id special_eot_id    = 32010;
+    id special_prefix_id = -1;
+    id special_suffix_id = -1;
+    id special_middle_id = -1;
+    id special_eot_id    = -1;

     bool add_space_prefix = true;

@@ -4072,6 +4080,30 @@ static void llm_load_vocab(
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;

+            // For Fill-In-the-Middle (FIM)/infill models that were converted
+            // prior to support for FIM special tokens in GGUF, the following
+            // allows those models to continue to work. The general names of
+            // the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
+            // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed
+            // once new versions of these models have been published.
+            std::string gen_name;
+            ml.get_key(LLM_KV_GENERAL_NAME, gen_name);
+            std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
+                [](unsigned char c){ return std::tolower(c); });
+            if (gen_name.find("code") != std::string::npos) {
+                if (model.arch == LLM_ARCH_LLAMA) {
+                    vocab.special_prefix_id = 32007;
+                    vocab.special_suffix_id = 32008;
+                    vocab.special_middle_id = 32009;
+                    vocab.special_eot_id    = 32010;
+                } else if (model.arch == LLM_ARCH_GEMMA) {
+                    vocab.special_prefix_id = 67;
+                    vocab.special_suffix_id = 69;
+                    vocab.special_middle_id = 68;
+                    vocab.special_eot_id    = 70;
+                }
+            }
+
             const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
             if (add_space_prefix_keyidx != -1) {
                 vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
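
The fallback above amounts to: if the model's general.name contains "code",
assume the per-architecture legacy ids. A Python paraphrase, purely
illustrative (the names below are not from the commit):

    # legacy FIM token ids for models converted before this commit
    LEGACY_FIM_IDS = {
        "llama": {"prefix": 32007, "suffix": 32008, "middle": 32009, "eot": 32010},
        "gemma": {"prefix": 67,    "suffix": 69,    "middle": 68,    "eot": 70},
    }

    def legacy_fim_ids(general_name: str, arch: str) -> dict | None:
        # older converted models carry no FIM metadata; guess from the name
        if "code" in general_name.lower():
            return LEGACY_FIM_IDS.get(arch)
        return None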
@@ -4185,13 +4217,17 @@ static void llm_load_vocab(
     // special tokens
     {
         const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,  vocab.special_bos_id  },
-            { LLM_KV_TOKENIZER_EOS_ID,  vocab.special_eos_id  },
-            { LLM_KV_TOKENIZER_UNK_ID,  vocab.special_unk_id  },
-            { LLM_KV_TOKENIZER_SEP_ID,  vocab.special_sep_id  },
-            { LLM_KV_TOKENIZER_PAD_ID,  vocab.special_pad_id  },
-            { LLM_KV_TOKENIZER_CLS_ID,  vocab.special_cls_id  },
-            { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
+            { LLM_KV_TOKENIZER_BOS_ID,    vocab.special_bos_id    },
+            { LLM_KV_TOKENIZER_EOS_ID,    vocab.special_eos_id    },
+            { LLM_KV_TOKENIZER_UNK_ID,    vocab.special_unk_id    },
+            { LLM_KV_TOKENIZER_SEP_ID,    vocab.special_sep_id    },
+            { LLM_KV_TOKENIZER_PAD_ID,    vocab.special_pad_id    },
+            { LLM_KV_TOKENIZER_CLS_ID,    vocab.special_cls_id    },
+            { LLM_KV_TOKENIZER_MASK_ID,   vocab.special_mask_id   },
+            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
+            { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
         };
         };
         for (const auto & it : special_token_types) {
             const std::string & key = kv(std::get<0>(it));