
convert-lora : make `--base` optional (#10110)

* convert-lora : make `--base` optional

* lint

* handle case where base_model_name_or_path is invalid

* do not include metadata from base model

* clarify unspecified --base

* add small comment [no ci]

* trigger ci
Xuan Son Nguyen 1 year ago
parent
commit
7554aa4655
2 changed files with 51 additions and 23 deletions
  1. convert_hf_to_gguf.py (+14 −13)
  2. convert_lora_to_gguf.py (+37 −10)
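
With this change, `--base` is optional: when it is omitted, the converter reads `base_model_name_or_path` from the adapter's adapter_config.json and pulls the base model's config from the Hugging Face hub via `transformers.AutoConfig`; only the config is fetched, never the weights. A minimal sketch of the new resolution order (hypothetical paths, simplified error handling; the real code reuses `Model.load_hparams` and is shown in the diffs below):

    import json
    from pathlib import Path
    from typing import Any

    from transformers import AutoConfig


    def resolve_base_hparams(dir_base_model: Path | None, lparams: dict[str, Any]) -> dict[str, Any]:
        # --base given: read the local config.json, as before
        if dir_base_model is not None:
            with open(dir_base_model / "config.json", encoding="utf-8") as f:
                return json.load(f)
        # --base omitted: fall back to the hub id recorded by PEFT in adapter_config.json;
        # AutoConfig downloads only the config, not the model weights
        model_id = lparams["base_model_name_or_path"]
        return AutoConfig.from_pretrained(model_id).to_dict()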

convert_hf_to_gguf.py (+14 −13)

@@ -72,7 +72,8 @@ class Model:
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
-                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
+                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
@@ -87,7 +88,7 @@ class Model:
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
             self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
-        self.hparams = Model.load_hparams(self.dir_model)
+        self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
@@ -1541,6 +1542,17 @@ class LlamaModel(Model):
             special_vocab._set_special_token("eot",    32010)
             special_vocab.add_to_gguf(self.gguf_writer)
 
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+        # Apply to granite small models only
+        if self.hparams.get("vocab_size", 32000) == 49152:
+            self.gguf_writer.add_add_bos_token(False)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -1557,17 +1569,6 @@ class LlamaModel(Model):
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                 self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
 
-        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
-        if tokenizer_config_file.is_file():
-            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
-                tokenizer_config_json = json.load(f)
-                if "add_prefix_space" in tokenizer_config_json:
-                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
-
-        # Apply to granite small models only
-        if self.hparams.get("vocab_size", 32000) == 49152:
-            self.gguf_writer.add_add_bos_token(False)
-
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
         if n_head_kv is not None and n_head != n_head_kv:
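
The `convert_hf_to_gguf.py` changes above are the enabling pieces: `Model.__init__` gains an optional `hparams` argument, so a caller can inject an already-loaded config dict instead of having the constructor read config.json from `dir_model`, and the tokenizer-related writes move from `LlamaModel.set_gguf_parameters()` into `set_vocab()` so that the LoRA adapter (which overrides `set_vocab` with a no-op and no longer calls `super().set_gguf_parameters()`) does not pick up base-model metadata. A minimal stand-in illustrating the constructor pattern (not the real `Model` class):

    import json
    from pathlib import Path
    from typing import Any


    class ExampleModel:
        def __init__(self, dir_model: Path, hparams: dict[str, Any] | None = None):
            # mirrors the diff: an injected dict takes precedence, otherwise read config.json from disk
            self.hparams = self.load_hparams(dir_model) if hparams is None else hparams

        @staticmethod
        def load_hparams(dir_model: Path) -> dict[str, Any]:
            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
                return json.load(f)

The LoRA converter uses this hook to pass the hub-fetched config through `hparams=hparams` when constructing its `LoraModel` subclass (see the end of the second file's diff below).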

convert_lora_to_gguf.py (+37 −10)

@@ -12,6 +12,7 @@ import json
 from math import prod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from transformers import AutoConfig
 
 import torch
 
@@ -256,8 +257,8 @@ def parse_args() -> argparse.Namespace:
         help="only print out what will be done, without writing any new files",
     )
     parser.add_argument(
-        "--base", type=Path, required=True,
-        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
+        "--base", type=Path,
+        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
     )
     parser.add_argument(
         "lora_path", type=Path,
@@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
+def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+    # normally, adapter does not come with base model config, we need to load it from AutoConfig
+    config = AutoConfig.from_pretrained(hf_model_id)
+    return config.to_dict()
+
+
 if __name__ == '__main__':
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@@ -281,7 +288,7 @@ if __name__ == '__main__':
 
     ftype = ftype_map[args.outtype]
 
-    dir_base_model: Path = args.base
+    dir_base_model: Path | None = args.base
     dir_lora: Path = args.lora_path
     lora_config = dir_lora / "adapter_config.json"
     input_model = dir_lora / "adapter_model.safetensors"
@@ -301,9 +308,29 @@ if __name__ == '__main__':
         input_model = os.path.join(dir_lora, "adapter_model.bin")
         lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
 
+    # load LoRA config
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
     # load base model
-    logger.info(f"Loading base model: {dir_base_model.name}")
-    hparams = Model.load_hparams(dir_base_model)
+    if dir_base_model is None:
+        if "base_model_name_or_path" in lparams:
+            model_id = lparams["base_model_name_or_path"]
+            logger.info(f"Loading base model from Hugging Face: {model_id}")
+            try:
+                hparams = load_hparams_from_hf(model_id)
+            except OSError as e:
+                logger.error(f"Failed to load base model config: {e}")
+                logger.error("Please try downloading the base model and add its path to --base")
+                sys.exit(1)
+        else:
+            logger.error("'base_model_name_or_path' is not found in adapter_config.json")
+            logger.error("Base model config is required. Please download the base model and add its path to --base")
+            sys.exit(1)
+    else:
+        logger.info(f"Loading base model: {dir_base_model.name}")
+        hparams = Model.load_hparams(dir_base_model)
+
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
@@ -323,13 +350,15 @@ if __name__ == '__main__':
                 self.dir_model_card = dir_lora_model
                 self.lora_alpha = float(lora_alpha)
 
+            def set_vocab(self):
+                pass
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
             def set_gguf_parameters(self):
                 self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
-                super().set_gguf_parameters()
 
             def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
@@ -350,7 +379,7 @@ if __name__ == '__main__':
                         logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                         if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                             logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
-                            logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
+                            logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
                         sys.exit(1)
 
                     if base_name in tensor_map:
@@ -384,9 +413,6 @@ if __name__ == '__main__':
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)
 
-        with open(lora_config, "r") as f:
-            lparams: dict[str, Any] = json.load(f)
-
         alpha: float = lparams["lora_alpha"]
 
         model_instance = LoraModel(
@@ -399,6 +425,7 @@ if __name__ == '__main__':
             dry_run=args.dry_run,
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
+            hparams=hparams,
         )
 
         logger.info("Exporting model...")
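
For completeness, the new fallback path in isolation: `load_hparams_from_hf` fetches only the base model's config from the hub and returns it as a plain dict, which then feeds `Model.from_model_architecture` and the `Model` constructor exactly like a locally read config.json. A hedged usage example (the model id is a placeholder, not taken from this PR; an unknown or unreachable id makes `AutoConfig.from_pretrained` raise `OSError`, which the converter turns into the "please pass `--base`" message above):

    from typing import Any

    from transformers import AutoConfig


    def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
        # same helper as in the diff: downloads the config only, never the weights
        config = AutoConfig.from_pretrained(hf_model_id)
        return config.to_dict()


    if __name__ == "__main__":
        hparams = load_hparams_from_hf("some-org/some-base-model")  # placeholder hub id
        # keys consumed downstream by the converter:
        print(hparams["architectures"])       # e.g. ["LlamaForCausalLM"]
        print(hparams["num_hidden_layers"])   # used to derive the block count / tensor name map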