|
|
@@ -17,6 +17,20 @@ logger = logging.getLogger("metadata")
|
|
|
|
|
|
@dataclass
|
|
|
class Metadata:
|
|
|
+ # Recommended Sampler Parameters to be written to GGUF KV Store
|
|
|
+ sampling_sequence: Optional[str] = None
|
|
|
+ sampling_top_k: Optional[int] = None
|
|
|
+ sampling_top_p: Optional[float] = None
|
|
|
+ sampling_min_p: Optional[float] = None
|
|
|
+ sampling_xtc_probability: Optional[float] = None
|
|
|
+ sampling_xtc_threshold: Optional[float] = None
|
|
|
+ sampling_temp: Optional[float] = None
|
|
|
+ sampling_penalty_last_n: Optional[int] = None
|
|
|
+ sampling_penalty_repeat: Optional[float] = None
|
|
|
+ sampling_mirostat: Optional[int] = None
|
|
|
+ sampling_mirostat_tau: Optional[float] = None
|
|
|
+ sampling_mirostat_eta: Optional[float] = None
|
|
|
+
|
|
|
# Authorship Metadata to be written to GGUF KV Store
|
|
|
name: Optional[str] = None
|
|
|
author: Optional[str] = None
|
|
|
@@ -54,15 +68,43 @@ class Metadata:
|
|
|
|
|
|
model_card = Metadata.load_model_card(model_path)
|
|
|
hf_params = Metadata.load_hf_parameters(model_path)
|
|
|
+ gen_config = Metadata.load_generation_config(model_path)
|
|
|
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
|
|
|
|
|
|
# heuristics
|
|
|
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
|
|
|
|
|
|
+ if gen_config:
|
|
|
+ metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
|
|
|
+ metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
|
|
|
+ metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
|
|
|
+ metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
|
|
|
+ metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
|
|
|
+ metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
|
|
|
+ metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
|
|
|
+ metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
|
|
|
+ metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
|
|
|
+ metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
|
|
|
+ metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
|
|
|
+ metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
|
|
|
+
|
|
|
# Metadata Override File Provided
|
|
|
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
|
|
metadata_override = Metadata.load_metadata_override(metadata_override_path)
|
|
|
|
|
|
+ metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
|
|
|
+ metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
|
|
|
+ metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
|
|
|
+ metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
|
|
|
+ metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
|
|
|
+ metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
|
|
|
+ metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
|
|
|
+ metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
|
|
|
+ metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
|
|
|
+ metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
|
|
|
+ metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
|
|
|
+ metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
|
|
|
+
|
|
|
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
|
|
|
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
|
|
|
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
|
|
|
@@ -172,6 +214,23 @@ class Metadata:
|
|
|
with open(config_path, "r", encoding="utf-8") as f:
|
|
|
return json.load(f)
|
|
|
|
|
|
@staticmethod
def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
    """Load ``generation_config.json`` from a model directory, if usable.

    This is a best-effort loader: every failure mode yields an empty dict
    instead of an exception.

    Args:
        model_path: Directory expected to contain ``generation_config.json``.

    Returns:
        The parsed top-level JSON object. An empty dict is returned when
        ``model_path`` is ``None`` or not a directory, when the file is
        absent, when it cannot be read or decoded, or when its top-level
        JSON value is not an object.
    """
    if model_path is None or not model_path.is_dir():
        return {}

    generation_config_path = model_path / "generation_config.json"

    if not generation_config_path.is_file():
        return {}

    try:
        with open(generation_config_path, "r", encoding="utf-8") as f:
            gen_config = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError):
        # not all models have valid generation_config.json; a file that is
        # not valid UTF-8 raises UnicodeDecodeError rather than JSONDecodeError
        return {}

    # Valid JSON is not necessarily an object (it may be a list, string, ...);
    # only a dict satisfies the declared return type.
    if not isinstance(gen_config, dict):
        return {}

    return gen_config
|
|
|
+
|
|
|
@staticmethod
|
|
|
def id_to_title(string):
|
|
|
# Convert capitalization into title form unless acronym or version number
|
|
|
@@ -546,6 +605,32 @@ class Metadata:
|
|
|
|
|
|
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
|
|
|
assert self.name is not None
|
|
|
+
|
|
|
+ if self.sampling_sequence is not None:
|
|
|
+ gguf_writer.add_sampling_sequence(self.sampling_sequence)
|
|
|
+ if self.sampling_top_k is not None:
|
|
|
+ gguf_writer.add_sampling_top_k(self.sampling_top_k)
|
|
|
+ if self.sampling_top_p is not None:
|
|
|
+ gguf_writer.add_sampling_top_p(self.sampling_top_p)
|
|
|
+ if self.sampling_min_p is not None:
|
|
|
+ gguf_writer.add_sampling_min_p(self.sampling_min_p)
|
|
|
+ if self.sampling_xtc_probability is not None:
|
|
|
+ gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
|
|
|
+ if self.sampling_xtc_threshold is not None:
|
|
|
+ gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
|
|
|
+ if self.sampling_temp is not None:
|
|
|
+ gguf_writer.add_sampling_temp(self.sampling_temp)
|
|
|
+ if self.sampling_penalty_last_n is not None:
|
|
|
+ gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
|
|
|
+ if self.sampling_penalty_repeat is not None:
|
|
|
+ gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
|
|
|
+ if self.sampling_mirostat is not None:
|
|
|
+ gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
|
|
|
+ if self.sampling_mirostat_tau is not None:
|
|
|
+ gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
|
|
|
+ if self.sampling_mirostat_eta is not None:
|
|
|
+ gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
|
|
|
+
|
|
|
gguf_writer.add_name(self.name)
|
|
|
|
|
|
if self.author is not None:
|