소스 검색

llama : add Falcon3 support (#10883)

* Add Falcon3 model support

* Add fix for adding bos to added special tokens

* Add comment explaining the logic behind the if statement

* Add a log message to better track the when the following line of code is triggered

* Update log to only print when input and output characters are different

* Fix handling pre-normalized tokens

* Refactoring
Billel Mokeddem 1 년 전
부모
커밋
7ae33a616f
3개의 변경된 파일29개의 추가작업 그리고 1개의 파일을 삭제
  1. 13 0
      convert_hf_to_gguf.py
  2. 1 0
      convert_hf_to_gguf_update.py
  3. 15 1
      src/llama.cpp

+ 13 - 0
convert_hf_to_gguf.py

@@ -529,9 +529,19 @@ class Model:
             else:
             else:
                 token: str = reverse_vocab[i]
                 token: str = reverse_vocab[i]
                 if token in added_vocab:
                 if token in added_vocab:
+                    # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
+                    # To avoid unexpected issues - we make sure to normalize non-normalized tokens
+                    if not tokenizer.added_tokens_decoder[i].normalized:
+                        previous_token = token
+                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
+                        if previous_token != token:
+                            logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
+
                     if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
                     if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
                         toktypes.append(gguf.TokenType.CONTROL)
                         toktypes.append(gguf.TokenType.CONTROL)
                     else:
                     else:
+                        # NOTE: this was added for Gemma.
+                        # Encoding and decoding the tokens above isn't sufficient for this case.
                         token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ")  # pre-normalize user-defined spaces
                         token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ")  # pre-normalize user-defined spaces
                         toktypes.append(gguf.TokenType.USER_DEFINED)
                         toktypes.append(gguf.TokenType.USER_DEFINED)
                 else:
                 else:
@@ -575,6 +585,9 @@ class Model:
         if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed":
         if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed":
             # ref: https://huggingface.co/tiiuae/falcon-7b
             # ref: https://huggingface.co/tiiuae/falcon-7b
             res = "falcon"
             res = "falcon"
+        if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e":
+            # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base
+            res = "falcon3"
         if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
         if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
             # ref: https://huggingface.co/BAAI/bge-small-en-v1.5
             # ref: https://huggingface.co/BAAI/bge-small-en-v1.5
             res = "bert-bge"
             res = "bert-bge"

+ 1 - 0
convert_hf_to_gguf_update.py

@@ -72,6 +72,7 @@ models = [
     {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
     {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
     {"name": "falcon",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
     {"name": "falcon",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
     {"name": "bert-bge",       "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
     {"name": "bert-bge",       "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
+    {"name": "falcon3",        "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", },
     {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", },
     {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", },
     {"name": "mpt",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
     {"name": "mpt",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
     {"name": "starcoder",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
     {"name": "starcoder",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },

+ 15 - 1
src/llama.cpp

@@ -1673,6 +1673,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
     LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
     LLM_CHAT_TEMPLATE_MISTRAL_V7,
     LLM_CHAT_TEMPLATE_MISTRAL_V7,
     LLM_CHAT_TEMPLATE_PHI_3,
     LLM_CHAT_TEMPLATE_PHI_3,
+    LLM_CHAT_TEMPLATE_FALCON_3,
     LLM_CHAT_TEMPLATE_ZEPHYR,
     LLM_CHAT_TEMPLATE_ZEPHYR,
     LLM_CHAT_TEMPLATE_MONARCH,
     LLM_CHAT_TEMPLATE_MONARCH,
     LLM_CHAT_TEMPLATE_GEMMA,
     LLM_CHAT_TEMPLATE_GEMMA,
@@ -1705,6 +1706,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
     { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
     { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7        },
     { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7        },
     { "phi3",              LLM_CHAT_TEMPLATE_PHI_3             },
     { "phi3",              LLM_CHAT_TEMPLATE_PHI_3             },
+    { "falcon3",           LLM_CHAT_TEMPLATE_FALCON_3          },
     { "zephyr",            LLM_CHAT_TEMPLATE_ZEPHYR            },
     { "zephyr",            LLM_CHAT_TEMPLATE_ZEPHYR            },
     { "monarch",           LLM_CHAT_TEMPLATE_MONARCH           },
     { "monarch",           LLM_CHAT_TEMPLATE_MONARCH           },
     { "gemma",             LLM_CHAT_TEMPLATE_GEMMA             },
     { "gemma",             LLM_CHAT_TEMPLATE_GEMMA             },
@@ -6562,7 +6564,8 @@ static void llm_load_vocab(
             } else if (
             } else if (
                     tokenizer_pre == "llama3"   ||
                     tokenizer_pre == "llama3"   ||
                     tokenizer_pre == "llama-v3" ||
                     tokenizer_pre == "llama-v3" ||
-                    tokenizer_pre == "llama-bpe") {
+                    tokenizer_pre == "llama-bpe"||
+                    tokenizer_pre == "falcon3") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
                 vocab.tokenizer_ignore_merges = true;
                 vocab.tokenizer_ignore_merges = true;
                 vocab.tokenizer_add_bos = true;
                 vocab.tokenizer_add_bos = true;
@@ -22615,6 +22618,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
         }
         }
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
         return LLM_CHAT_TEMPLATE_PHI_3;
         return LLM_CHAT_TEMPLATE_PHI_3;
+    } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
+        return LLM_CHAT_TEMPLATE_FALCON_3;
     } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
     } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
         return LLM_CHAT_TEMPLATE_ZEPHYR;
         return LLM_CHAT_TEMPLATE_ZEPHYR;
     } else if (tmpl_contains("bos_token + message['role']")) {
     } else if (tmpl_contains("bos_token + message['role']")) {
@@ -22767,6 +22772,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
         if (add_ass) {
             ss << "<|assistant|>\n";
             ss << "<|assistant|>\n";
         }
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
+        // Falcon 3
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>\n" << message->content << "\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
     } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
         // zephyr template
         // zephyr template
         for (auto message : chat) {
         for (auto message : chat) {