
convert-*.py: GGUF Naming Convention Refactor and Metadata Override Refactor (#7499)

The main change is that the default output filename now takes this form:

{name}{parameters}{finetune}{version}{encoding}{kind}

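For example, a Q8_0 conversion of an instruct finetune could end up named something like `Mixtral-8x7B-Instruct-v0.1-Q8_0.gguf` (illustrative only; the exact fields depend on what metadata the heuristics can recover), while a vocab-only export marks itself via the kind suffix.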
In addition, this adds and removes some entries in the KV store, and adds a metadata class with automatic heuristics to derive some values from model card content (see the example override file after the key lists below).

* No Change:
  - Internal GGUF Spec
    - `general.architecture`
    - `general.quantization_version`
    - `general.alignment`
    - `general.file_type`
  - General Model Details
    - `general.name`
    - `general.author`
    - `general.version`
    - `general.description`
  - Licensing details
    - `general.license`
  - Typically represents the converted GGUF repo (unless made from scratch)
    - `general.url`
  - Model Source during conversion
    - `general.source.url`

* Removed:
  - Model Source during conversion
    - `general.source.huggingface.repository`

* Added:
  - General Model Details
    - `general.organization`
    - `general.finetune`
    - `general.basename`
    - `general.quantized_by`
    - `general.size_label`
  - Licensing details
    - `general.license.name`
    - `general.license.link`
  - Typically represents the converted GGUF repo (unless made from scratch)
    - `general.doi`
    - `general.uuid`
    - `general.repo_url`
  - Model Source during conversion
    - `general.source.doi`
    - `general.source.uuid`
    - `general.source.repo_url`
  - Base Model Source
    - `general.base_model.count`
    - `general.base_model.{id}.name`
    - `general.base_model.{id}.author`
    - `general.base_model.{id}.version`
    - `general.base_model.{id}.organization`
    - `general.base_model.{id}.url` (Model Website/Paper)
    - `general.base_model.{id}.doi`
    - `general.base_model.{id}.uuid`
    - `general.base_model.{id}.repo_url` (Model Source Repository (git/svn/etc...))
  - Array based KV stores
    - `general.tags`
    - `general.languages`
    - `general.datasets`
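As a rough sketch of what the new `--metadata` override file could look like (this assumes the flat `general.*` JSON key scheme used by the old Metadata.load in examples/convert_legacy_llama.py; the exact keys accepted by gguf.Metadata.load may differ):

    {
      "general.name": "Example Model",
      "general.author": "Example Author",
      "general.version": "v1.0",
      "general.license": "apache-2.0"
    }

Invocation would then be along the lines of:

    python convert_hf_to_gguf.py <model-dir> --outtype q8_0 --metadata metadata-override.json

Fields given in the override are meant to take precedence over values derived heuristically from the model card.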

---------

Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Brian, 1 year ago
Parent commit: 672a6f1018

convert_hf_to_gguf.py (+112, -75)

@@ -48,34 +48,38 @@ class Model:
 
     dir_model: Path
     ftype: gguf.LlamaFileType
+    fname_out: Path | None
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
     lazy: bool
-    model_name: str | None
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
     block_count: int
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
-    fname_out: Path
     gguf_writer: gguf.GGUFWriter
+    model_name: str | None
+    metadata_override: Path | None
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
-                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
+                 use_temp_file: bool = False, eager: bool = False,
+                 metadata_override: Path | None = None, model_name: str | None = None,
+                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
+
         self.dir_model = dir_model
         self.ftype = ftype
+        self.fname_out = fname_out
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
         self.lazy = not eager
-        self.model_name = model_name
         self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
@@ -84,6 +88,10 @@ class Model:
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
+        self.metadata_override = metadata_override
+        self.model_name = model_name
+
+        # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
             # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
             _, first_tensor = next(self.get_tensors())
@@ -93,10 +101,8 @@ class Model:
             else:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
-        ftype_up: str = self.ftype.name.partition("_")[2].upper()
-        ftype_lw: str = ftype_up.lower()
-        # allow templating the file name with the output ftype, useful with the "auto" ftype
-        self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
+
+        # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
 
@@ -193,7 +199,6 @@ class Model:
         return new_name
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_block_count(self.block_count)
 
         if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
@@ -250,7 +255,7 @@ class Model:
 
         return False
 
-    def write_tensors(self):
+    def prepare_tensors(self):
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
         for name, data_torch in self.get_tensors():
@@ -333,9 +338,67 @@ class Model:
 
                 self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
 
+    def set_type(self):
+        self.gguf_writer.add_type(gguf.GGUFType.MODEL)
+
+    def prepare_metadata(self, vocab_only: bool):
+
+        total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
+
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
+
+        # Fallback to model directory name if metadata name is still missing
+        if self.metadata.name is None:
+            self.metadata.name = self.dir_model.name
+
+        # Generate parameter weight class (useful for leader boards) if not yet determined
+        if self.metadata.size_label is None and total_params > 0:
+            self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
+
+        # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
+        output_type: str = self.ftype.name.partition("_")[2]
+
+        # Filename Output
+        # Note: `not is_dir()` is used because `.is_file()` will not detect
+        #       file template strings as it doesn't actually exist as a file
+        if self.fname_out is not None and not self.fname_out.is_dir():
+            # Output path is a custom defined templated filename
+
+            # Process templated file name with the output ftype, useful with the "auto" ftype
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
+        else:
+            # Generate default filename based on model specification and available metadata
+            if not vocab_only:
+                fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None)
+            else:
+                fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab")
+
+            # Check if preferred output directory path was provided
+            if self.fname_out is not None and self.fname_out.is_dir():
+                # output path is a directory
+                self.fname_out = self.fname_out / f"{fname_default}.gguf"
+            else:
+                # output in the same directory as the model by default
+                self.fname_out = self.dir_model / f"{fname_default}.gguf"
+
+        self.set_type()
+
+        logger.info("Set meta model")
+        self.metadata.set_gguf_meta_model(self.gguf_writer)
+
+        logger.info("Set model parameters")
+        self.set_gguf_parameters()
+
+        logger.info("Set model tokenizer")
+        self.set_vocab()
+
+        logger.info("Set model quantization version")
+        self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+
     def write(self):
-        self.write_tensors()
-        self.gguf_writer.write_header_to_file(self.fname_out)
+        self.prepare_tensors()
+        self.prepare_metadata(vocab_only=False)
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
@@ -343,7 +406,9 @@ class Model:
     def write_vocab(self):
         if len(self.gguf_writer.tensors) != 1:
             raise ValueError('Splitting the vocabulary is not supported')
-        self.gguf_writer.write_header_to_file(self.fname_out)
+
+        self.prepare_metadata(vocab_only=True)
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
 
@@ -780,7 +845,6 @@ class GPTNeoXModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -836,7 +900,6 @@ class BloomModel(Model):
     model_arch = gguf.MODEL_ARCH.BLOOM
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("Bloom")
         n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
         n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
         self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
@@ -913,7 +976,6 @@ class MPTModel(Model):
 
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layers"]
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_block_count(block_count)
@@ -952,7 +1014,6 @@ class OrionModel(Model):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        hf_repo = self.hparams.get("_name_or_path", "")
 
         ctx_length = 0
         if "max_sequence_length" in self.hparams:
@@ -965,8 +1026,6 @@ class OrionModel(Model):
             raise ValueError("gguf: can not find ctx length parameter.")
 
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
-        self.gguf_writer.add_source_hf_repo(hf_repo)
         self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
         self.gguf_writer.add_context_length(ctx_length)
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -990,7 +1049,6 @@ class BaichuanModel(Model):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        hf_repo = self.hparams.get("_name_or_path", "")
 
         ctx_length = 0
         if "max_sequence_length" in self.hparams:
@@ -1002,8 +1060,6 @@ class BaichuanModel(Model):
         else:
             raise ValueError("gguf: can not find ctx length parameter.")
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
-        self.gguf_writer.add_source_hf_repo(hf_repo)
         self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
         self.gguf_writer.add_context_length(ctx_length)
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1117,7 +1173,6 @@ class XverseModel(Model):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        hf_repo = self.hparams.get("_name_or_path", "")
 
         ctx_length = 0
         if "max_sequence_length" in self.hparams:
@@ -1129,8 +1184,6 @@ class XverseModel(Model):
         else:
             raise ValueError("gguf: can not find ctx length parameter.")
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
-        self.gguf_writer.add_source_hf_repo(hf_repo)
         self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
         self.gguf_writer.add_context_length(ctx_length)
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1189,7 +1242,6 @@ class FalconModel(Model):
         if n_head_kv is None:
             n_head_kv = self.hparams.get("n_head_kv", 1)  # old name
 
-        self.gguf_writer.add_name("Falcon")
         self.gguf_writer.add_context_length(2048)  # not in config.json
         self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1234,7 +1286,6 @@ class StarCoderModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layer"]
 
-        self.gguf_writer.add_name("StarCoder")
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
         self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@@ -1269,7 +1320,6 @@ class RefactModel(Model):
 
         block_count = self.hparams["n_layer"]
 
-        self.gguf_writer.add_name("Refact")
         # refact uses Alibi. So this is from config.json which might be used by training.
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@@ -1324,7 +1374,6 @@ class StableLMModel(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -1386,8 +1435,8 @@ class StableLMModel(Model):
 
         return [(new_name, data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._q_norms is not None or self._k_norms is not None:
             # flatten two `list[dict[str, Tensor]]` into a single `list[str]`
@@ -1503,8 +1552,8 @@ class LlamaModel(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -1567,7 +1616,6 @@ class GrokModel(Model):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_name("Grok")
 
     _experts: list[dict[str, Tensor]] | None = None
 
@@ -1616,7 +1664,6 @@ class DbrxModel(Model):
     def set_gguf_parameters(self):
         ffn_config = self.hparams["ffn_config"]
         attn_config = self.hparams["attn_config"]
-        self.gguf_writer.add_name(self.hparams["model_type"])
         self.gguf_writer.add_block_count(self.hparams["n_layers"])
 
         self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
@@ -1685,7 +1732,6 @@ class MiniCPMModel(Model):
 
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
-        self.gguf_writer.add_name("MiniCPM")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -1755,7 +1801,6 @@ class QwenModel(Model):
         self._set_vocab_qwen()
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("Qwen")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1831,8 +1876,8 @@ class Qwen2MoeModel(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -1846,7 +1891,6 @@ class GPT2Model(Model):
     model_arch = gguf.MODEL_ARCH.GPT2
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
         self.gguf_writer.add_context_length(self.hparams["n_ctx"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@@ -1889,7 +1933,6 @@ class Phi2Model(Model):
         n_embd = self.find_hparam(["hidden_size", "n_embd"])
         n_head = self.find_hparam(["num_attention_heads", "n_head"])
 
-        self.gguf_writer.add_name("Phi2")
         self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
 
         self.gguf_writer.add_embedding_length(n_embd)
@@ -2011,7 +2054,6 @@ class Phi3MiniModel(Model):
         orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
         rope_dims = n_embd // n_head
 
-        self.gguf_writer.add_name("Phi3")
         self.gguf_writer.add_context_length(max_pos_embds)
         self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
         self.gguf_writer.add_embedding_length(n_embd)
@@ -2068,7 +2110,6 @@ class PlamoModel(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name("PLaMo")
         self.gguf_writer.add_context_length(4096)  # not in config.json
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
@@ -2113,7 +2154,6 @@ class CodeShellModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layer"]
 
-        self.gguf_writer.add_name("CodeShell")
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
         self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@@ -2272,7 +2312,6 @@ class InternLM2Model(Model):
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -2440,7 +2479,6 @@ class GemmaModel(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -2481,7 +2519,6 @@ class Gemma2Model(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -2556,7 +2593,6 @@ class MambaModel(Model):
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
@@ -2735,7 +2771,6 @@ class OpenELMModel(Model):
         assert self.block_count == len(self._num_query_heads)
         assert self.block_count == len(self._ffn_dims)
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_context_length(self.hparams["max_context_length"])
         self.gguf_writer.add_embedding_length(n_embd)
@@ -2909,8 +2944,8 @@ class ArcticModel(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -2988,8 +3023,8 @@ class DeepseekV2Model(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -3107,7 +3142,6 @@ class T5Model(Model):
         self.gguf_writer.add_add_eos_token(True)
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("T5")
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
             n_ctx = 512
@@ -3181,7 +3215,6 @@ class JaisModel(Model):
         self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@@ -3227,8 +3260,8 @@ class JaisModel(Model):
 
         return tensors
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
         self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
 
 
@@ -3539,6 +3572,10 @@ def parse_args() -> argparse.Namespace:
         "--no-tensor-first-split", action="store_true",
         "--no-tensor-first-split", action="store_true",
         help="do not add tensors to the first split (disabled by default)"
         help="do not add tensors to the first split (disabled by default)"
     )
     )
+    parser.add_argument(
+        "--metadata", type=Path,
+        help="Specify the path for an authorship metadata override file"
+    )
 
 
     return parser.parse_args()
     return parser.parse_args()
 
 
@@ -3564,7 +3601,10 @@ def split_str_to_n_bytes(split_str: str) -> int:
 def main() -> None:
     args = parse_args()
 
-    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+    if args.verbose:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
 
     dir_model = args.model
 
@@ -3585,37 +3625,33 @@ def main() -> None:
         logger.error("Error: Cannot use temp file when splitting")
         logger.error("Error: Cannot use temp file when splitting")
         sys.exit(1)
         sys.exit(1)
 
 
+    fname_out = None
+
     if args.outfile is not None:
     if args.outfile is not None:
         fname_out = args.outfile
         fname_out = args.outfile
-    else:
-        # output in the same directory as the model by default
-        fname_out = dir_model / 'ggml-model-{ftype}.gguf'
 
 
     logger.info(f"Loading model: {dir_model.name}")
     logger.info(f"Loading model: {dir_model.name}")
 
 
     hparams = Model.load_hparams(dir_model)
     hparams = Model.load_hparams(dir_model)
 
 
     with torch.inference_mode():
     with torch.inference_mode():
+        output_type = ftype_map[args.outtype]
+        model_architecture = hparams["architectures"][0]
+
         try:
         try:
-            model_class = Model.from_model_architecture(hparams["architectures"][0])
+            model_class = Model.from_model_architecture(model_architecture)
         except NotImplementedError:
         except NotImplementedError:
-            logger.error(f"Model {hparams['architectures'][0]} is not supported")
+            logger.error(f"Model {model_architecture} is not supported")
             sys.exit(1)
             sys.exit(1)
 
 
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
-                                     args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
+        model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
+                                     is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
+                                     eager=args.no_lazy,
+                                     metadata_override=args.metadata, model_name=args.model_name,
+                                     split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split)
                                      small_first_shard=args.no_tensor_first_split)
 
 
-        logger.info("Set model parameters")
-        model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
-        model_instance.set_gguf_parameters()
-
-        logger.info("Set model tokenizer")
-        model_instance.set_vocab()
-
-        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-
         if args.vocab_only:
         if args.vocab_only:
             logger.info("Exporting model vocab...")
             logger.info("Exporting model vocab...")
             model_instance.write_vocab()
             model_instance.write_vocab()
@@ -3623,6 +3659,7 @@ def main() -> None:
         else:
             logger.info("Exporting model...")
             model_instance.write()
+            assert model_instance.fname_out is not None
             out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
             logger.info(f"Model successfully exported to {out_path}")
 

convert_lora_to_gguf.py (+19, -10)

@@ -251,6 +251,10 @@ def parse_args() -> argparse.Namespace:
         "--verbose", action="store_true",
         "--verbose", action="store_true",
         help="increase output verbosity",
         help="increase output verbosity",
     )
     )
+    parser.add_argument(
+        "--dry-run", action="store_true",
+        help="only print out what will be done, without writing any new files",
+    )
     parser.add_argument(
     parser.add_argument(
         "--base", type=Path, required=True,
         "--base", type=Path, required=True,
         help="directory containing base model file",
         help="directory containing base model file",
@@ -300,6 +304,12 @@ if __name__ == '__main__':
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
     hparams = Model.load_hparams(dir_base_model)
+
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
+    alpha: float = lparams["lora_alpha"]
+
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
@@ -310,6 +320,14 @@ if __name__ == '__main__':
         class LoraModel(model_class):
             model_arch = model_class.model_arch
 
+            def set_type(self):
+                self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
+                self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+
+            def set_gguf_parameters(self):
+                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+                super().set_gguf_parameters()
+
             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_map: dict[str, PartialLoraTensor] = {}
 
@@ -357,18 +375,9 @@ if __name__ == '__main__':
             is_big_endian=args.bigendian,
             use_temp_file=False,
             eager=args.no_lazy,
-            model_name=None,
+            dry_run=args.dry_run,
         )
 
-        with open(lora_config, "r") as f:
-            lparams: dict[str, Any] = json.load(f)
-
-        alpha = lparams["lora_alpha"]
-
-        model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER)
-        model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
-        model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
-        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
        logger.info("Exporting model...")
         model_instance.write()
         logger.info(f"Model successfully exported to {model_instance.fname_out}")

examples/convert_legacy_llama.py (+128, -107)

@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
+from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar
 
 import numpy as np
 
@@ -346,42 +346,6 @@ class Params:
         return params
 
 
-@dataclass
-class Metadata:
-    name: Optional[str] = None
-    author: Optional[str] = None
-    version: Optional[str] = None
-    url: Optional[str] = None
-    description: Optional[str] = None
-    license: Optional[str] = None
-    source_url: Optional[str] = None
-    source_hf_repo: Optional[str] = None
-
-    @staticmethod
-    def load(metadata_path: Path) -> Metadata:
-        if metadata_path is None or not metadata_path.exists():
-            return Metadata()
-
-        with open(metadata_path, 'r') as file:
-            data = json.load(file)
-
-        # Create a new Metadata instance
-        metadata = Metadata()
-
-        # Assigning values to Metadata attributes if they exist in the JSON file
-        # This is based on LLM_KV_NAMES mapping in llama.cpp
-        metadata.name = data.get("general.name")
-        metadata.author = data.get("general.author")
-        metadata.version = data.get("general.version")
-        metadata.url = data.get("general.url")
-        metadata.description = data.get("general.description")
-        metadata.license = data.get("general.license")
-        metadata.source_url = data.get("general.source.url")
-        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
-
-        return metadata
-
-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
@@ -806,7 +770,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
-    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
+    def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:
@@ -824,16 +788,73 @@ class OutputFile:
                 self.gguf.add_author(metadata.author)
             if metadata.version is not None:
                 self.gguf.add_version(metadata.version)
-            if metadata.url is not None:
-                self.gguf.add_url(metadata.url)
+            if metadata.organization is not None:
+                self.gguf.add_organization(metadata.organization)
+
+            if metadata.finetune is not None:
+                self.gguf.add_finetune(metadata.finetune)
+            if metadata.basename is not None:
+                self.gguf.add_basename(metadata.basename)
+
             if metadata.description is not None:
                 self.gguf.add_description(metadata.description)
+            if metadata.quantized_by is not None:
+                self.gguf.add_quantized_by(metadata.quantized_by)
+
+            if metadata.size_label is not None:
+                self.gguf.add_size_label(metadata.size_label)
+
             if metadata.license is not None:
-                self.gguf.add_licence(metadata.license)
+                self.gguf.add_license(metadata.license)
+            if metadata.license_name is not None:
+                self.gguf.add_license_name(metadata.license_name)
+            if metadata.license_link is not None:
+                self.gguf.add_license_link(metadata.license_link)
+
+            if metadata.url is not None:
+                self.gguf.add_url(metadata.url)
+            if metadata.doi is not None:
+                self.gguf.add_doi(metadata.doi)
+            if metadata.uuid is not None:
+                self.gguf.add_uuid(metadata.uuid)
+            if metadata.repo_url is not None:
+                self.gguf.add_repo_url(metadata.repo_url)
+
             if metadata.source_url is not None:
                 self.gguf.add_source_url(metadata.source_url)
-            if metadata.source_hf_repo is not None:
-                self.gguf.add_source_hf_repo(metadata.source_hf_repo)
+            if metadata.source_doi is not None:
+                self.gguf.add_source_doi(metadata.source_doi)
+            if metadata.source_uuid is not None:
+                self.gguf.add_source_uuid(metadata.source_uuid)
+            if metadata.source_repo_url is not None:
+                self.gguf.add_source_repo_url(metadata.source_repo_url)
+
+            if metadata.base_models is not None:
+                self.gguf.add_base_model_count(len(metadata.base_models))
+                for key, base_model_entry in enumerate(metadata.base_models):
+                    if "name" in base_model_entry:
+                        self.gguf.add_base_model_name(key, base_model_entry["name"])
+                    if "author" in base_model_entry:
+                        self.gguf.add_base_model_author(key, base_model_entry["author"])
+                    if "version" in base_model_entry:
+                        self.gguf.add_base_model_version(key, base_model_entry["version"])
+                    if "organization" in base_model_entry:
+                        self.gguf.add_base_model_organization(key, base_model_entry["organization"])
+                    if "url" in base_model_entry:
+                        self.gguf.add_base_model_url(key, base_model_entry["url"])
+                    if "doi" in base_model_entry:
+                        self.gguf.add_base_model_doi(key, base_model_entry["doi"])
+                    if "uuid" in base_model_entry:
+                        self.gguf.add_base_model_uuid(key, base_model_entry["uuid"])
+                    if "repo_url" in base_model_entry:
+                        self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
+
+            if metadata.tags is not None:
+                self.gguf.add_tags(metadata.tags)
+            if metadata.languages is not None:
+                self.gguf.add_languages(metadata.languages)
+            if metadata.datasets is not None:
+                self.gguf.add_datasets(metadata.datasets)
 
     def add_meta_arch(self, params: Params) -> None:
         # Metadata About The Neural Architecture Itself
@@ -944,7 +965,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -978,7 +999,7 @@ class OutputFile:
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
-        metadata: Metadata | None = None,
+        metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -1021,35 +1042,32 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
     raise ValueError(f"Unexpected combination of types: {name_to_type}")
 
 
-def model_parameter_count(model: LazyModel) -> int:
-    total_model_parameters = 0
-    for i, (name, lazy_tensor) in enumerate(model.items()):
-        sum_weights_in_tensor = 1
+def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]:
+    total_params = 0
+    shared_params = 0
+    expert_params = 0
+
+    for name, lazy_tensor in tensors:
+        # We don't need these
+        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+            continue
+
+        # Got A Tensor
+        sum_weights_in_tensor: int = 1
+
+        # Tensor Volume
         for dim in lazy_tensor.shape:
             sum_weights_in_tensor *= dim
-        total_model_parameters += sum_weights_in_tensor
-    return total_model_parameters
-
-
-def model_parameter_count_rounded_notation(model_params_count: int) -> str:
-    if model_params_count > 1e12 :
-        # Trillions Of Parameters
-        scaled_model_params = model_params_count * 1e-12
-        scale_suffix = "T"
-    elif model_params_count > 1e9 :
-        # Billions Of Parameters
-        scaled_model_params = model_params_count * 1e-9
-        scale_suffix = "B"
-    elif model_params_count > 1e6 :
-        # Millions Of Parameters
-        scaled_model_params = model_params_count * 1e-6
-        scale_suffix = "M"
-    else:
-        # Thousands Of Parameters
-        scaled_model_params = model_params_count * 1e-3
-        scale_suffix = "K"
 
-    return f"{round(scaled_model_params)}{scale_suffix}"
+        if ".experts." in name:
+            if ".experts.0." in name:
+                expert_params += sum_weights_in_tensor
+        else:
+            shared_params += sum_weights_in_tensor
+
+        total_params += sum_weights_in_tensor
+
+    return total_params, shared_params, expert_params
 
 
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -1231,34 +1249,24 @@ class VocabFactory:
         return vocab, special_vocab
 
 
-def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> str:
-    quantization = {
+def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str:
+    name = metadata.name if metadata.name is not None else None
+    basename = metadata.basename if metadata.basename is not None else None
+    finetune = metadata.finetune if metadata.finetune is not None else None
+    version = metadata.version if metadata.version is not None else None
+    size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0)
+
+    output_type = {
         GGMLFileType.AllF32:    "F32",
         GGMLFileType.MostlyF16: "F16",
         GGMLFileType.MostlyQ8_0: "Q8_0",
     }[file_type]
 
-    parameters = model_parameter_count_rounded_notation(model_params_count)
-
-    expert_count = ""
-    if params.n_experts is not None:
-        expert_count = f"{params.n_experts}x"
-
-    version = ""
-    if metadata is not None and metadata.version is not None:
-        version = f"-{metadata.version}"
+    return gguf.naming_convention(name, basename, finetune, version, size_label, output_type)
 
-    name = "ggml-model"
-    if metadata is not None and metadata.name is not None:
-        name = metadata.name
-    elif params.path_model is not None:
-        name = params.path_model.name
 
-    return f"{name}{version}-{expert_count}{parameters}-{quantization}"
-
-
-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
-    default_filename = default_convention_outfile(file_type, params, model_params_count, metadata)
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path:
+    default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
         logger.error(
@@ -1297,8 +1305,9 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--pad-vocab",    action="store_true",    help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--pad-vocab",    action="store_true",    help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true",    help="skip unknown tensor names instead of failing")
     parser.add_argument("--skip-unknown", action="store_true",    help="skip unknown tensor names instead of failing")
     parser.add_argument("--verbose",      action="store_true",    help="increase output verbosity")
     parser.add_argument("--verbose",      action="store_true",    help="increase output verbosity")
-    parser.add_argument("--metadata",     type=Path,              help="Specify the path for a metadata file")
+    parser.add_argument("--metadata",     type=Path,              help="Specify the path for an authorship metadata override file")
     parser.add_argument("--get-outfile",  action="store_true",    help="get calculated default outfile name")
     parser.add_argument("--get-outfile",  action="store_true",    help="get calculated default outfile name")
+    parser.add_argument("--model-name",   type=str, default=None, help="name of the model")
 
     args = parser.parse_args(args_in)
 
@@ -1310,32 +1319,36 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         logging.basicConfig(level=logging.INFO)
 
-    metadata = Metadata.load(args.metadata)
+    model_name = args.model_name
+    dir_model = args.model
+
+    metadata = gguf.Metadata.load(args.metadata, dir_model, model_name)
 
     if args.get_outfile:
-        model_plus = load_some_model(args.model)
+        model_plus = load_some_model(dir_model)
         params = Params.load(model_plus)
-        model   = convert_model_names(model_plus.model, params, args.skip_unknown)
-        model_params_count = model_parameter_count(model_plus.model)
-        ftype   = pick_output_type(model, args.outtype)
-        print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}") # noqa: NP100
+        model = convert_model_names(model_plus.model, params, args.skip_unknown)
+        model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+        ftype = pick_output_type(model, args.outtype)
+
+        if (metadata is None or metadata.name is None) and params.path_model is not None:
+            metadata.name = params.path_model.name
+
+        print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100
         return
 
     if args.no_vocab and args.vocab_only:
         raise ValueError("--vocab-only does not make sense with --no-vocab")
 
     if args.dump_single:
-        model_plus = lazy_load_file(args.model)
+        model_plus = lazy_load_file(dir_model)
         do_dump_model(model_plus)
         return
 
     if not args.vocab_only:
-        model_plus = load_some_model(args.model)
+        model_plus = load_some_model(dir_model)
     else:
-        model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
-
-    model_params_count = model_parameter_count(model_plus.model)
-    logger.info(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")
+        model_plus = ModelPlus(model = {}, paths = [dir_model / 'dummy'], format = 'none', vocab = None)
 
     if args.dump:
         do_dump_model(model_plus)
@@ -1368,7 +1381,7 @@ def main(args_in: list[str] | None = None) -> None:
         logger.info(f"params = {params}")
         logger.info(f"params = {params}")
 
 
     model_parent_path = model_plus.paths[0].parent
     model_parent_path = model_plus.paths[0].parent
-    vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
+    vocab_path = Path(args.vocab_dir or dir_model or model_parent_path)
     vocab_factory = VocabFactory(vocab_path)
     vocab_types = None if args.no_vocab else args.vocab_type.split(",")
     vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
@@ -1399,13 +1412,21 @@ def main(args_in: list[str] | None = None) -> None:
 
     assert params is not None
 
+    if metadata.name is None and params.path_model is not None:
+        metadata.name = params.path_model.name
+
+    model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+    logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})")
+
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
     model   = model_plus.model
     model   = model_plus.model
     model   = convert_model_names(model, params, args.skip_unknown)
     model   = convert_model_names(model, params, args.skip_unknown)
     ftype   = pick_output_type(model, args.outtype)
     ftype   = pick_output_type(model, args.outtype)
     model   = convert_to_output_type(model, ftype)
     model   = convert_to_output_type(model, ftype)
-    outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)
+    outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)
+
+    metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0)
 
     params.ftype = ftype
     logger.info(f"Writing {outfile}, format {ftype}")

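Note on the counting change above: the old single total is replaced by a (total, shared, expert) split so that MoE models can get an `NxM`-style size label. A minimal, self-contained sketch of the same counting rule applied to plain name/shape pairs (the tensor names and shapes below are made up for illustration):

```python
from math import prod

def count_params(tensors: list[tuple[str, tuple[int, ...]]]) -> tuple[int, int, int]:
    total = shared = expert = 0
    for name, shape in tensors:
        # Same exclusions as per_model_weight_count_estimation()
        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
            continue
        n = prod(shape)
        if ".experts." in name:
            if ".experts.0." in name:
                expert += n  # only one representative expert is counted
        else:
            shared += n
        total += n
    return total, shared, expert

# Hypothetical two-expert toy model
print(count_params([
    ("blk.0.ffn.experts.0.w1.weight", (14336, 4096)),
    ("blk.0.ffn.experts.1.w1.weight", (14336, 4096)),
    ("blk.0.attn_q.weight",           (4096, 4096)),
]))  # -> (134217728, 16777216, 58720256)
```

The total still includes every expert tensor; only the per-expert figure is restricted to expert 0, which is what `size_label()` later combines with the expert count.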
+ 8 - 0
gguf-py/README.md

@@ -78,5 +78,13 @@ python -m build
 python -m twine upload dist/*
 ```
 
+## Run Unit Tests
+
+From root of this repository you can run this command to run all the unit tests
+
+```bash
+python -m unittest discover ./gguf-py -v
+```
+
 ## TODO
 - [ ] Include conversion scripts as command line entry points in this package.

+ 2 - 0
gguf-py/gguf/__init__.py

@@ -5,3 +5,5 @@ from .gguf_writer import *
 from .quants import *
 from .tensor_mapping import *
 from .vocab import *
+from .utility import *
+from .metadata import *

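With the two new re-exports, the helpers land in the top-level `gguf` namespace. A quick sanity check, a minimal sketch assuming the gguf-py tree from this commit is importable (values are illustrative):

```python
import gguf

meta = gguf.Metadata(name="Example Model", version="v1.0")   # from gguf.metadata
print(gguf.naming_convention(meta.name, None, None, meta.version, None, "F16"))
# -> Example-Model-v1.0-F16  (from gguf.utility)
```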
+ 54 - 14
gguf-py/gguf/constants.py

@@ -19,19 +19,60 @@ GGML_QUANT_VERSION     = 2  # GGML_QNT_VERSION from ggml.h
 
 class Keys:
     class General:
-        TYPE                 = "general.type"
-        ARCHITECTURE         = "general.architecture"
-        QUANTIZATION_VERSION = "general.quantization_version"
-        ALIGNMENT            = "general.alignment"
-        NAME                 = "general.name"
-        AUTHOR               = "general.author"
-        VERSION              = "general.version"
-        URL                  = "general.url"
-        DESCRIPTION          = "general.description"
-        LICENSE              = "general.license"
-        SOURCE_URL           = "general.source.url"
-        SOURCE_HF_REPO       = "general.source.huggingface.repository"
-        FILE_TYPE            = "general.file_type"
+        TYPE                       = "general.type"
+        ARCHITECTURE               = "general.architecture"
+        QUANTIZATION_VERSION       = "general.quantization_version"
+        ALIGNMENT                  = "general.alignment"
+        FILE_TYPE                  = "general.file_type"
+
+        # Authorship Metadata
+        NAME                       = "general.name"
+        AUTHOR                     = "general.author"
+        VERSION                    = "general.version"
+        ORGANIZATION               = "general.organization"
+
+        FINETUNE                   = "general.finetune"
+        BASENAME                   = "general.basename"
+
+        DESCRIPTION                = "general.description"
+        QUANTIZED_BY               = "general.quantized_by"
+
+        SIZE_LABEL                 = "general.size_label"
+
+        # Licensing details
+        LICENSE                    = "general.license"
+        LICENSE_NAME               = "general.license.name"
+        LICENSE_LINK               = "general.license.link"
+
+        # Typically represents the converted GGUF repo (Unless native)
+        URL                        = "general.url" # Model Website/Paper
+        DOI                        = "general.doi"
+        UUID                       = "general.uuid"
+        REPO_URL                   = "general.repo_url" # Model Source Repository (git/svn/etc...)
+
+        # Model Source during conversion
+        SOURCE_URL                 = "general.source.url" # Model Website/Paper
+        SOURCE_DOI                 = "general.source.doi"
+        SOURCE_UUID                = "general.source.uuid"
+        SOURCE_REPO_URL            = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
+
+        # Base Model Source. There can be more than one source if it's a merged
+        # model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
+        # tracing the lineage of models as they are finetuned or merged over time.
+        BASE_MODEL_COUNT           = "general.base_model.count"
+        BASE_MODEL_NAME            = "general.base_model.{id}.name"
+        BASE_MODEL_AUTHOR          = "general.base_model.{id}.author"
+        BASE_MODEL_VERSION         = "general.base_model.{id}.version"
+        BASE_MODEL_ORGANIZATION    = "general.base_model.{id}.organization"
+        BASE_MODEL_URL             = "general.base_model.{id}.url" # Model Website/Paper
+        BASE_MODEL_DOI             = "general.base_model.{id}.doi"
+        BASE_MODEL_UUID            = "general.base_model.{id}.uuid"
+        BASE_MODEL_REPO_URL        = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
+
+        # Array based KV stores
+        TAGS                       = "general.tags"
+        LANGUAGES                  = "general.languages"
+        DATASETS                   = "general.datasets"
 
     class LLM:
         VOCAB_SIZE                        = "{arch}.vocab_size"
@@ -1233,7 +1274,6 @@ KEY_GENERAL_URL                  = Keys.General.URL
 KEY_GENERAL_DESCRIPTION          = Keys.General.DESCRIPTION
 KEY_GENERAL_LICENSE              = Keys.General.LICENSE
 KEY_GENERAL_SOURCE_URL           = Keys.General.SOURCE_URL
-KEY_GENERAL_SOURCE_HF_REPO       = Keys.General.SOURCE_HF_REPO
 KEY_GENERAL_FILE_TYPE            = Keys.General.FILE_TYPE
 
 # LLM

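The `general.base_model.{id}.*` entries are key templates. A small sketch of how they expand (the indices are hypothetical):

```python
from gguf.constants import Keys

# The {id} placeholder is filled once per base model entry.
print(Keys.General.BASE_MODEL_NAME.format(id=0))      # -> general.base_model.0.name
print(Keys.General.BASE_MODEL_REPO_URL.format(id=1))  # -> general.base_model.1.repo_url
print(Keys.General.LICENSE_NAME)                      # -> general.license.name
```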
+ 140 - 18
gguf-py/gguf/gguf_writer.py

@@ -7,6 +7,7 @@ import struct
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, auto
+from math import prod
 from pathlib import Path
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
@@ -106,6 +107,53 @@ class GGUFWriter:
 
         self.add_architecture()
 
+    def get_total_parameter_count(self) -> tuple[int, int, int, int]:
+        total_params = 0
+        shared_params = 0
+        expert_params = 0
+
+        expert_sum = 0
+        n_expert_tensors = 0
+
+        last_lora_a: tuple[str, TensorInfo] | None = None
+
+        for tensors in self.tensors:
+            for name, info in tensors.items():
+
+                shape = info.shape
+
+                if name.endswith(".lora_a"):
+                    last_lora_a = (name, info)
+                    continue
+                elif name.endswith(".lora_b"):
+                    if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
+                        # Bail when the LoRA pair can't be found trivially
+                        logger.warning("can't measure LoRA size correctly, tensor order is unusual")
+                        return 0, 0, 0, 0
+                    else:
+                        shape = (*shape[:-1], last_lora_a[1].shape[-1])
+
+                size = prod(shape)
+
+                if "_exps." in name:
+                    expert_params += (size // shape[-3])
+                    expert_sum += shape[-3]
+                    n_expert_tensors += 1
+                else:
+                    shared_params += size
+
+                total_params += size
+
+        # Hopefully this should work even for variable-expert-count models
+        expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
+
+        # Negate the total to signal it's likely not exact
+        if last_lora_a is not None:
+            total_params = -total_params
+
+        # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
+        return total_params, shared_params, expert_params, expert_count
+
     def format_shard_names(self, path: Path) -> list[Path]:
         if len(self.tensors) == 1:
             return [path]
@@ -115,6 +163,7 @@ class GGUFWriter:
         if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
             # allow calling this multiple times as long as the path is the same
             return
+
         if self.state is not WriterState.NO_FILE:
             raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
 
@@ -136,6 +185,8 @@ class GGUFWriter:
 
         if self.dry_run:
             logger.info("Dry run, not writing files")
+            for name in filenames:
+                print(name)  # noqa: NP100
             exit()
 
         return filenames
@@ -430,43 +481,114 @@ class GGUFWriter:
     def add_architecture(self) -> None:
         self.add_string(Keys.General.ARCHITECTURE, self.arch)
 
+    def add_quantization_version(self, quantization_version: int) -> None:
+        self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
+
+    def add_custom_alignment(self, alignment: int) -> None:
+        self.data_alignment = alignment
+        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+
+    def add_file_type(self, ftype: int) -> None:
+        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+
+    def add_name(self, name: str) -> None:
+        self.add_string(Keys.General.NAME, name)
+
     def add_author(self, author: str) -> None:
         self.add_string(Keys.General.AUTHOR, author)
 
     def add_version(self, version: str) -> None:
         self.add_string(Keys.General.VERSION, version)
 
-    def add_tensor_data_layout(self, layout: str) -> None:
-        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+    def add_organization(self, organization: str) -> None:
+        self.add_string(Keys.General.ORGANIZATION, organization)
 
-    def add_url(self, url: str) -> None:
-        self.add_string(Keys.General.URL, url)
+    def add_finetune(self, finetune: str) -> None:
+        self.add_string(Keys.General.FINETUNE, finetune)
+
+    def add_basename(self, basename: str) -> None:
+        self.add_string(Keys.General.BASENAME, basename)
 
     def add_description(self, description: str) -> None:
         self.add_string(Keys.General.DESCRIPTION, description)
 
-    def add_licence(self, licence: str) -> None:
-        self.add_string(Keys.General.LICENSE, licence)
+    def add_quantized_by(self, quantized: str) -> None:
+        self.add_string(Keys.General.QUANTIZED_BY, quantized)
+
+    def add_size_label(self, size_label: str) -> None:
+        self.add_string(Keys.General.SIZE_LABEL, size_label)
+
+    def add_license(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE, license)
+
+    def add_license_name(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE_NAME, license)
+
+    def add_license_link(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE_LINK, license)
+
+    def add_url(self, url: str) -> None:
+        self.add_string(Keys.General.URL, url)
+
+    def add_doi(self, doi: str) -> None:
+        self.add_string(Keys.General.DOI, doi)
+
+    def add_uuid(self, uuid: str) -> None:
+        self.add_string(Keys.General.UUID, uuid)
+
+    def add_repo_url(self, repo_url: str) -> None:
+        self.add_string(Keys.General.REPO_URL, repo_url)
 
     def add_source_url(self, url: str) -> None:
         self.add_string(Keys.General.SOURCE_URL, url)
 
-    def add_source_hf_repo(self, repo: str) -> None:
-        self.add_string(Keys.General.SOURCE_HF_REPO, repo)
+    def add_source_doi(self, doi: str) -> None:
+        self.add_string(Keys.General.SOURCE_DOI, doi)
 
-    def add_file_type(self, ftype: int) -> None:
-        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+    def add_source_uuid(self, uuid: str) -> None:
+        self.add_string(Keys.General.SOURCE_UUID, uuid)
 
-    def add_name(self, name: str) -> None:
-        self.add_string(Keys.General.NAME, name)
+    def add_source_repo_url(self, repo_url: str) -> None:
+        self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
 
-    def add_quantization_version(self, quantization_version: int) -> None:
-        self.add_uint32(
-            Keys.General.QUANTIZATION_VERSION, quantization_version)
+    def add_base_model_count(self, source_count: int) -> None:
+        self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
 
-    def add_custom_alignment(self, alignment: int) -> None:
-        self.data_alignment = alignment
-        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+    def add_base_model_name(self, source_id: int, name: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
+
+    def add_base_model_author(self, source_id: int, author: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
+
+    def add_base_model_version(self, source_id: int, version: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
+
+    def add_base_model_organization(self, source_id: int, organization: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
+
+    def add_base_model_url(self, source_id: int, url: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
+
+    def add_base_model_doi(self, source_id: int, doi: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
+
+    def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
+
+    def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
+
+    def add_tags(self, tags: Sequence[str]) -> None:
+        self.add_array(Keys.General.TAGS, tags)
+
+    def add_languages(self, languages: Sequence[str]) -> None:
+        self.add_array(Keys.General.LANGUAGES, languages)
+
+    def add_datasets(self, datasets: Sequence[str]) -> None:
+        self.add_array(Keys.General.DATASETS, datasets)
+
+    def add_tensor_data_layout(self, layout: str) -> None:
+        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_vocab_size(self, size: int) -> None:
         self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)

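A minimal sketch of the new authorship setters in use (file name, architecture and values are hypothetical; the usual tensor/KV writing and `close()` calls are elided):

```python
import gguf

writer = gguf.GGUFWriter("example.gguf", "llama")
writer.add_name("Example Model 7B Instruct")
writer.add_organization("Example Org")
writer.add_basename("example-model")
writer.add_finetune("Instruct")
writer.add_size_label("7B")
writer.add_license("apache-2.0")
writer.add_base_model_count(1)
writer.add_base_model_name(0, "Example Base 7B")
writer.add_base_model_repo_url(0, "https://huggingface.co/example/example-base-7b")
writer.add_tags(["instruct", "example"])
# ... tensors would be added next, then the header/KV/tensor writes and close(),
# exactly as the converters already do.
```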
+ 485 - 0
gguf-py/gguf/metadata.py

@@ -0,0 +1,485 @@
+from __future__ import annotations
+
+import re
+import json
+import yaml
+import logging
+from pathlib import Path
+from typing import Any, Literal, Optional
+from dataclasses import dataclass
+
+from .constants import Keys
+
+import gguf
+
+logger = logging.getLogger("metadata")
+
+
+@dataclass
+class Metadata:
+    # Authorship Metadata to be written to GGUF KV Store
+    name: Optional[str] = None
+    author: Optional[str] = None
+    version: Optional[str] = None
+    organization: Optional[str] = None
+    finetune: Optional[str] = None
+    basename: Optional[str] = None
+    description: Optional[str] = None
+    quantized_by: Optional[str] = None
+    size_label: Optional[str] = None
+    url: Optional[str] = None
+    doi: Optional[str] = None
+    uuid: Optional[str] = None
+    repo_url: Optional[str] = None
+    source_url: Optional[str] = None
+    source_doi: Optional[str] = None
+    source_uuid: Optional[str] = None
+    source_repo_url: Optional[str] = None
+    license: Optional[str] = None
+    license_name: Optional[str] = None
+    license_link: Optional[str] = None
+    base_models: Optional[list[dict]] = None
+    tags: Optional[list[str]] = None
+    languages: Optional[list[str]] = None
+    datasets: Optional[list[str]] = None
+
+    @staticmethod
+    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
+        # This grabs as much contextual authorship metadata as possible from the model repository
+        # making any conversion as required to match the gguf kv store metadata format
+        # as well as giving users the ability to override any authorship metadata that may be incorrect
+
+        # Create a new Metadata instance
+        metadata = Metadata()
+
+        model_card = Metadata.load_model_card(model_path)
+        hf_params = Metadata.load_hf_parameters(model_path)
+
+        # heuristics
+        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
+
+        # Metadata Override File Provided
+        # This is based on LLM_KV_NAMES mapping in llama.cpp
+        metadata_override = Metadata.load_metadata_override(metadata_override_path)
+
+        metadata.author          = metadata_override.get(Keys.General.AUTHOR,          metadata.author)
+        metadata.version         = metadata_override.get(Keys.General.VERSION,         metadata.version)
+        metadata.organization    = metadata_override.get(Keys.General.ORGANIZATION,    metadata.organization)
+
+        metadata.finetune        = metadata_override.get(Keys.General.FINETUNE,        metadata.finetune)
+        metadata.basename        = metadata_override.get(Keys.General.BASENAME,        metadata.basename)
+
+        metadata.description     = metadata_override.get(Keys.General.DESCRIPTION,     metadata.description)
+        metadata.quantized_by    = metadata_override.get(Keys.General.QUANTIZED_BY,    metadata.quantized_by)
+
+        metadata.size_label      = metadata_override.get(Keys.General.SIZE_LABEL,      metadata.size_label)
+        metadata.license_name    = metadata_override.get(Keys.General.LICENSE_NAME,    metadata.license_name)
+        metadata.license_link    = metadata_override.get(Keys.General.LICENSE_LINK,    metadata.license_link)
+
+        metadata.url             = metadata_override.get(Keys.General.URL,             metadata.url)
+        metadata.doi             = metadata_override.get(Keys.General.DOI,             metadata.doi)
+        metadata.uuid            = metadata_override.get(Keys.General.UUID,            metadata.uuid)
+        metadata.repo_url        = metadata_override.get(Keys.General.REPO_URL,        metadata.repo_url)
+
+        metadata.source_url      = metadata_override.get(Keys.General.SOURCE_URL,      metadata.source_url)
+        metadata.source_doi      = metadata_override.get(Keys.General.SOURCE_DOI,      metadata.source_doi)
+        metadata.source_uuid     = metadata_override.get(Keys.General.SOURCE_UUID,     metadata.source_uuid)
+        metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
+
+        # Base Models is received here as an array of models
+        metadata.base_models     = metadata_override.get("general.base_models",        metadata.base_models)
+
+        metadata.tags            = metadata_override.get(Keys.General.TAGS,            metadata.tags)
+        metadata.languages       = metadata_override.get(Keys.General.LANGUAGES,       metadata.languages)
+        metadata.datasets        = metadata_override.get(Keys.General.DATASETS,        metadata.datasets)
+
+        # Direct Metadata Override (via direct cli argument)
+        if model_name is not None:
+            metadata.name = model_name
+
+        return metadata
+
+    @staticmethod
+    def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
+        if metadata_override_path is None or not metadata_override_path.is_file():
+            return {}
+
+        with open(metadata_override_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
+        if model_path is None or not model_path.is_dir():
+            return {}
+
+        model_card_path = model_path / "README.md"
+
+        if not model_card_path.is_file():
+            return {}
+
+        # The model card metadata is assumed to always be in YAML
+        # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
+        with open(model_card_path, "r", encoding="utf-8") as f:
+            if f.readline() == "---\n":
+                raw = f.read().partition("---\n")[0]
+                data = yaml.safe_load(raw)
+                if isinstance(data, dict):
+                    return data
+                else:
+                    logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
+                    return {}
+            else:
+                return {}
+
+    @staticmethod
+    def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
+        if model_path is None or not model_path.is_dir():
+            return {}
+
+        config_path = model_path / "config.json"
+
+        if not config_path.is_file():
+            return {}
+
+        with open(config_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def id_to_title(string):
+        # Convert capitalization into title form unless acronym or version number
+        return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
+
+    @staticmethod
+    def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
+        # Huggingface often store model id as '<org>/<model name>'
+        # so let's parse it and apply some heuristics if possible for model name components
+
+        if model_id is None:
+            # model ID missing
+            return None, None, None, None, None, None
+
+        if ' ' in model_id:
+            # model ID is actually a normal human sentence
+            # which means it's most likely a normal model name only
+            # not part of the hugging face naming standard, but whatever
+            return model_id, None, None, None, None, None
+
+        if '/' in model_id:
+            # model ID (huggingface style)
+            org_component, model_full_name_component = model_id.split('/', 1)
+        else:
+            # model ID but missing org components
+            org_component, model_full_name_component = None, model_id
+
+        # Check if we erroneously matched against './' or '../' etc...
+        if org_component is not None and org_component[0] == '.':
+            org_component = None
+
+        name_parts: list[str] = model_full_name_component.split('-')
+        name_types: list[
+            set[Literal["basename", "size_label", "finetune", "version", "type"]]
+        ] = [set() for _ in name_parts]
+
+        # Annotate the name
+        for i, part in enumerate(name_parts):
+            # Version
+            if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
+                name_types[i].add("version")
+            # Quant type (should not be there for base models, but still annotated)
+            elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
+                name_types[i].add("type")
+                name_parts[i] = part.upper()
+            # Model size
+            elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
+                part = part.replace("_", ".")
+                # Handle weird bloom-7b1 notation
+                if part[-1].isdecimal():
+                    part = part[:-2] + "." + part[-1] + part[-2]
+                # Normalize the size suffixes
+                if len(part) > 1 and part[-2].isdecimal():
+                    if part[-1] in "kmbt":
+                        part = part[:-1] + part[-1].upper()
+                if total_params != 0:
+                    try:
+                        label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
+                        # Only use it as a size label if it's close or bigger than the model size
+                        # Note that LoRA adapters don't necessarily include all layers,
+                        # so this is why bigger label sizes are accepted.
+                        # Do not use the size label when it's smaller than 1/8 of the model size
+                        if (total_params < 0 and label_params < abs(total_params) // 8) or (
+                            # Check both directions when the current model isn't a LoRA adapter
+                            total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
+                        ):
+                            # Likely a context length
+                            name_types[i].add("finetune")
+                            # Lowercase the size when it's a context length
+                            part = part[:-1] + part[-1].lower()
+                    except ValueError:
+                        # Failed to convert the size label to float, use it anyway
+                        pass
+                if len(name_types[i]) == 0:
+                    name_types[i].add("size_label")
+                name_parts[i] = part
+            # Some easy to recognize finetune names
+            elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
+                name_types[i].add("finetune")
+                if part.lower() == "lora":
+                    name_parts[i] = "LoRA"
+
+        at_start = True
+        # Find the basename through the annotated name
+        for part, t in zip(name_parts, name_types):
+            if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
+                t.add("basename")
+            else:
+                if at_start:
+                    at_start = False
+                if len(t) == 0:
+                    t.add("finetune")
+
+        # Remove the basename annotation from trailing version
+        for part, t in zip(reversed(name_parts), reversed(name_types)):
+            if "basename" in t:
+                if len(t) > 1:
+                    t.remove("basename")
+            else:
+                break
+
+        basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
+        size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
+        finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
+        # TODO: should the basename version always be excluded?
+        # TODO: should multiple versions be joined together?
+        version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
+
+        if size_label is None and finetune is None and version is None:
+            # Too ambiguous, output nothing
+            basename = None
+
+        return model_full_name_component, org_component, basename, finetune, version, size_label
+
+    @staticmethod
+    def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
+        # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+
+        # Model Card Heuristics
+        ########################
+        if model_card is not None:
+
+            if "model_name" in model_card and metadata.name is None:
+                # Not part of the huggingface model card standard, but some model creators use it,
+                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+                metadata.name = model_card.get("model_name")
+
+            if "model_creator" in model_card and metadata.author is None:
+                # Not part of the huggingface model card standard, but some model creators use it,
+                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+                metadata.author = model_card.get("model_creator")
+
+            if "model_type" in model_card and metadata.basename is None:
+                # Not part of the huggingface model card standard, but some model creators use it,
+                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+                metadata.basename = model_card.get("model_type")
+
+            if "base_model" in model_card:
+                # This represents the parent models that this is based on
+                # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
+                # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
+                metadata_base_models = []
+                base_model_value = model_card.get("base_model", None)
+
+                if base_model_value is not None:
+                    if isinstance(base_model_value, str):
+                        metadata_base_models.append(base_model_value)
+                    elif isinstance(base_model_value, list):
+                        metadata_base_models.extend(base_model_value)
+
+                if metadata.base_models is None:
+                    metadata.base_models = []
+
+                for model_id in metadata_base_models:
+                    # NOTE: model size of base model is assumed to be similar to the size of the current model
+                    model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+                    base_model = {}
+                    if model_full_name_component is not None:
+                        base_model["name"] = Metadata.id_to_title(model_full_name_component)
+                    if org_component is not None:
+                        base_model["organization"] = Metadata.id_to_title(org_component)
+                    if version is not None:
+                        base_model["version"] = version
+                    if org_component is not None and model_full_name_component is not None:
+                        base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
+                    metadata.base_models.append(base_model)
+
+            if "license" in model_card and metadata.license is None:
+                metadata.license = model_card.get("license")
+
+            if "license_name" in model_card and metadata.license_name is None:
+                metadata.license_name = model_card.get("license_name")
+
+            if "license_link" in model_card and metadata.license_link is None:
+                metadata.license_link = model_card.get("license_link")
+
+            tags_value = model_card.get("tags", None)
+            if tags_value is not None:
+
+                if metadata.tags is None:
+                    metadata.tags = []
+
+                if isinstance(tags_value, str):
+                    metadata.tags.append(tags_value)
+                elif isinstance(tags_value, list):
+                    metadata.tags.extend(tags_value)
+
+            pipeline_tags_value = model_card.get("pipeline_tag", None)
+            if pipeline_tags_value is not None:
+
+                if metadata.tags is None:
+                    metadata.tags = []
+
+                if isinstance(pipeline_tags_value, str):
+                    metadata.tags.append(pipeline_tags_value)
+                elif isinstance(pipeline_tags_value, list):
+                    metadata.tags.extend(pipeline_tags_value)
+
+            language_value = model_card.get("languages", model_card.get("language", None))
+            if language_value is not None:
+
+                if metadata.languages is None:
+                    metadata.languages = []
+
+                if isinstance(language_value, str):
+                    metadata.languages.append(language_value)
+                elif isinstance(language_value, list):
+                    metadata.languages.extend(language_value)
+
+            dataset_value = model_card.get("datasets", model_card.get("dataset", None))
+            if dataset_value is not None:
+
+                if metadata.datasets is None:
+                    metadata.datasets = []
+
+                if isinstance(dataset_value, str):
+                    metadata.datasets.append(dataset_value)
+                elif isinstance(dataset_value, list):
+                    metadata.datasets.extend(dataset_value)
+
+        # Hugging Face Parameter Heuristics
+        ####################################
+
+        if hf_params is not None:
+
+            hf_name_or_path = hf_params.get("_name_or_path")
+            if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
+                # Use _name_or_path only if it's actually a model name and not a filesystem path
+                # e.g. 'meta-llama/Llama-2-7b-hf'
+                model_id = hf_name_or_path
+                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+                if metadata.name is None and model_full_name_component is not None:
+                    metadata.name = Metadata.id_to_title(model_full_name_component)
+                if metadata.organization is None and org_component is not None:
+                    metadata.organization = Metadata.id_to_title(org_component)
+                if metadata.basename is None and basename is not None:
+                    metadata.basename = basename
+                if metadata.finetune is None and finetune is not None:
+                    metadata.finetune = finetune
+                if metadata.version is None and version is not None:
+                    metadata.version = version
+                if metadata.size_label is None and size_label is not None:
+                    metadata.size_label = size_label
+
+        # Directory Folder Name Fallback Heuristics
+        ############################################
+        if model_path is not None:
+            model_id = model_path.name
+            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+            if metadata.name is None and model_full_name_component is not None:
+                metadata.name = Metadata.id_to_title(model_full_name_component)
+            if metadata.organization is None and org_component is not None:
+                metadata.organization = Metadata.id_to_title(org_component)
+            if metadata.basename is None and basename is not None:
+                metadata.basename = basename
+            if metadata.finetune is None and finetune is not None:
+                metadata.finetune = finetune
+            if metadata.version is None and version is not None:
+                metadata.version = version
+            if metadata.size_label is None and size_label is not None:
+                metadata.size_label = size_label
+
+        return metadata
+
+    def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
+        assert self.name is not None
+        gguf_writer.add_name(self.name)
+
+        if self.author is not None:
+            gguf_writer.add_author(self.author)
+        if self.version is not None:
+            gguf_writer.add_version(self.version)
+        if self.organization is not None:
+            gguf_writer.add_organization(self.organization)
+
+        if self.finetune is not None:
+            gguf_writer.add_finetune(self.finetune)
+        if self.basename is not None:
+            gguf_writer.add_basename(self.basename)
+
+        if self.description is not None:
+            gguf_writer.add_description(self.description)
+        if self.quantized_by is not None:
+            gguf_writer.add_quantized_by(self.quantized_by)
+
+        if self.size_label is not None:
+            gguf_writer.add_size_label(self.size_label)
+
+        if self.license is not None:
+            gguf_writer.add_license(self.license)
+        if self.license_name is not None:
+            gguf_writer.add_license_name(self.license_name)
+        if self.license_link is not None:
+            gguf_writer.add_license_link(self.license_link)
+
+        if self.url is not None:
+            gguf_writer.add_url(self.url)
+        if self.doi is not None:
+            gguf_writer.add_doi(self.doi)
+        if self.uuid is not None:
+            gguf_writer.add_uuid(self.uuid)
+        if self.repo_url is not None:
+            gguf_writer.add_repo_url(self.repo_url)
+
+        if self.source_url is not None:
+            gguf_writer.add_source_url(self.source_url)
+        if self.source_doi is not None:
+            gguf_writer.add_source_doi(self.source_doi)
+        if self.source_uuid is not None:
+            gguf_writer.add_source_uuid(self.source_uuid)
+        if self.source_repo_url is not None:
+            gguf_writer.add_source_repo_url(self.source_repo_url)
+
+        if self.base_models is not None:
+            gguf_writer.add_base_model_count(len(self.base_models))
+            for key, base_model_entry in enumerate(self.base_models):
+                if "name" in base_model_entry:
+                    gguf_writer.add_base_model_name(key, base_model_entry["name"])
+                if "author" in base_model_entry:
+                    gguf_writer.add_base_model_author(key, base_model_entry["author"])
+                if "version" in base_model_entry:
+                    gguf_writer.add_base_model_version(key, base_model_entry["version"])
+                if "organization" in base_model_entry:
+                    gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
+                if "url" in base_model_entry:
+                    gguf_writer.add_base_model_url(key, base_model_entry["url"])
+                if "doi" in base_model_entry:
+                    gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
+                if "uuid" in base_model_entry:
+                    gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
+                if "repo_url" in base_model_entry:
+                    gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
+
+        if self.tags is not None:
+            gguf_writer.add_tags(self.tags)
+        if self.languages is not None:
+            gguf_writer.add_languages(self.languages)
+        if self.datasets is not None:
+            gguf_writer.add_datasets(self.datasets)

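Putting `Metadata` to work: a minimal sketch of the override path (directory, file name and values are hypothetical; only keys that `Metadata.load()` actually reads from the override file are shown):

```python
import json
from pathlib import Path

import gguf

override = {
    "general.organization": "Example Org",
    "general.quantized_by": "Example User",
    "general.license.name": "Apache License 2.0",
    "general.license.link": "https://www.apache.org/licenses/LICENSE-2.0",
    "general.tags":         ["instruct", "example"],
}
Path("override.json").write_text(json.dumps(override), encoding="utf-8")

# Heuristics run first (model card, config.json, directory name), then the override wins.
meta = gguf.Metadata.load(Path("override.json"), Path("./Mixtral-8x7B-Instruct-v0.1"))
print(meta.name, meta.organization, meta.size_label)
# -> Mixtral 8x7B Instruct v0.1 Example Org 8x7B
```

Note that the model name itself is not taken from the override file; it comes from the heuristics or from the `--model-name` CLI argument.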
+ 69 - 0
gguf-py/gguf/utility.py

@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from typing import Literal
+
+
+def fill_templated_filename(filename: str, output_type: str | None) -> str:
+    # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
+    ftype_lowercase: str = output_type.lower() if output_type is not None else ""
+    ftype_uppercase: str = output_type.upper() if output_type is not None else ""
+    return filename.format(ftype_lowercase,
+                           outtype=ftype_lowercase, ftype=ftype_lowercase,
+                           OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
+
+
+def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
+    if model_params_count > 1e12 :
+        # Trillions Of Parameters
+        scaled_model_params = model_params_count * 1e-12
+        scale_suffix = "T"
+    elif model_params_count > 1e9 :
+        # Billions Of Parameters
+        scaled_model_params = model_params_count * 1e-9
+        scale_suffix = "B"
+    elif model_params_count > 1e6 :
+        # Millions Of Parameters
+        scaled_model_params = model_params_count * 1e-6
+        scale_suffix = "M"
+    else:
+        # Thousands Of Parameters
+        scaled_model_params = model_params_count * 1e-3
+        scale_suffix = "K"
+
+    fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
+
+    return f"{scaled_model_params:.{fix}f}{scale_suffix}"
+
+
+def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
+
+    if expert_count > 0:
+        pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
+        size_class = f"{expert_count}x{pretty_size}"
+    else:
+        size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
+
+    return size_class
+
+
+def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
+    # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
+
+    if base_name is not None:
+        name = base_name.strip().title().replace(' ', '-').replace('/', '-')
+    elif model_name is not None:
+        name = model_name.strip().title().replace(' ', '-').replace('/', '-')
+    else:
+        name = "ggml-model"
+
+    parameters = f"-{size_label}" if size_label is not None else ""
+
+    finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
+
+    version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
+
+    encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
+
+    kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
+
+    return f"{name}{parameters}{finetune}{version}{encoding}{kind}"

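A worked example of the sizing and naming helpers above (parameter counts are illustrative):

```python
import gguf

# Dense model: the label is derived from the total parameter count
label = gguf.size_label(total_params=8_030_000_000, shared_params=8_030_000_000,
                        expert_params=0, expert_count=0)
print(label)  # -> 8.0B

# MoE model: the label combines expert count with shared + one-expert weights
print(gguf.size_label(total_params=46_700_000_000, shared_params=1_300_000_000,
                      expert_params=5_670_000_000, expert_count=8))  # -> 8x7.0B

print(gguf.naming_convention("Meta Llama 3 8B Instruct", "Meta-Llama-3", "Instruct",
                             None, label, "Q8_0"))
# -> Meta-Llama-3-8.0B-Instruct-Q8_0
```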
+ 1 - 0
gguf-py/pyproject.toml

@@ -22,6 +22,7 @@ classifiers = [
 python = ">=3.8"
 python = ">=3.8"
 numpy = ">=1.17"
 numpy = ">=1.17"
 tqdm = ">=4.27"
 tqdm = ">=4.27"
+pyyaml = ">=5.1"
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"

+ 1 - 0
gguf-py/tests/__init__.py

@@ -0,0 +1 @@
+from .test_metadata import *

+ 0 - 7
gguf-py/tests/test_gguf.py

@@ -1,7 +0,0 @@
-import gguf  # noqa: F401  # pyright: ignore[reportUnusedImport]
-
-# TODO: add tests
-
-
-def test_write_gguf() -> None:
-    pass

+ 158 - 0
gguf-py/tests/test_metadata.py

@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+
+import unittest
+from pathlib import Path
+import os
+import sys
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf
+
+
+class TestMetadataMethod(unittest.TestCase):
+
+    def test_id_to_title(self):
+        self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
+        self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
+        self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
+
+    def test_get_model_id_components(self):
+        # This is the basic standard form with organization marker
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
+                         ('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Similar to basic standard form but without organization marker
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
+                         ('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Missing version
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
+                         ('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
+
+        # Missing finetune
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
+                         ('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
+
+        # Base name and size label only
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
+                         ('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
+
+        # Base name and version only
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
+                         ('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
+
+        ## Edge Cases ##
+
+        # This is too ambiguous... best to err on the side of caution and output nothing
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
+                         ('Mixtral', None, None, None, None, None))
+
+        # Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
+                         ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
+
+        # Can't detect all non-standard forms in a heuristically safe way... best to err on the side of caution and output nothing...
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
+                         ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
+
+        # Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
+                         ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
+
+        # Check that it can handle a real model id with no version code
+        # Note that 4k in this string is non-standard and Microsoft was referring to context length rather than weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
+                         ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))
+
+        # There are some legitimate models with only thousands of parameters
+        self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
+                         ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
+
+        # Non-standard and not easy to disambiguate
+        self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
+                         ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
+
+        # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
+        self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
+                         ('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))
+
+        # This is a real model id where the weight size has a decimal point
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
+                         ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
+
+        # Uses an underscore in the size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
+                         ('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
+
+        # Uses Iter3 for the version
+        self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
+                         ('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))
+
+        # Has two potential versions in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
+                         ('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))
+
+        # Potential version in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
+                         ('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))
+
+        # Underscore in the basename, and 1m for the context size
+        self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
+                         ('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))
+
+        # Version before the finetune name
+        self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
+                         ('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))
+
+        # TODO: hf suffix which could be ignored but isn't
+        self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
+                         ('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))
+
+        # Two sizes; don't merge them. The second one is the number of tokens it was trained on, not a parameter count
+        self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
+                         ('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))
+
+        # It's a trap: there is no size label here (the actual parameter count passed in is ~1.34B)
+        self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
+                         ('relu-100B', 'SparseLLM', 'relu', '100b', None, None))
+
+        # Unusual size notation: '7b1' stands for 7.1B
+        self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
+                         ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
+
+    def test_apply_metadata_heuristic_from_model_card(self):
+        model_card = {
+            'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
+            'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
+            'language': ['en'],
+            'datasets': ['teknium/OpenHermes-2.5'],
+            'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
+            'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
+        }
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        expect = gguf.Metadata()
+        expect.base_models = [{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
+        expect.tags = ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
+        expect.languages = ['en']
+        expect.datasets = ['teknium/OpenHermes-2.5']
+
+        self.assertEqual(got, expect)
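+
+        # For reference: 'base_model' ids are expanded into base_models entries
+        # (with a Hugging Face repo_url derived from the id), while 'tags',
+        # 'language' and 'datasets' carry over to tags, languages and datasets.
+        # The remaining card fields are simply not asserted here.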
+
+    def test_apply_metadata_heuristic_from_hf_parameters(self):
+        hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
+        self.assertEqual(got, expect)
+
+    def test_apply_metadata_heuristic_from_model_dir(self):
+        model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
+        self.assertEqual(got, expect)
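+
+        # The expectation here is identical to the hf_params case above: the model
+        # directory name appears to be run through the same naming heuristic as the
+        # "_name_or_path" value, so both derive the same metadata.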
+
+
+if __name__ == "__main__":
+    unittest.main()
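
A minimal sketch of exercising the heuristic interactively, assuming the gguf-py package is importable as `gguf` (the model id and the expected tuple are taken from the assertions above):

    import gguf

    # Split a Hugging Face style model id into its naming components.
    components = gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1")
    print(components)
    # Per the tests above:
    # ('Mixtral-8x7B-Instruct-v0.1', 'Mistral', 'Mixtral', 'Instruct', 'v0.1', '8x7B')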