# metadata.py
from __future__ import annotations

import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional

import yaml

import gguf
from .constants import Keys
  11. logger = logging.getLogger("metadata")
@dataclass
class Metadata:
    """Authorship and recommended-sampler metadata of a model, to be written to a GGUF KV store.

    Every field defaults to None, meaning "unknown". ``Metadata.load`` populates an
    instance from a model directory and an optional JSON override file, and
    ``set_gguf_meta_model`` writes the fields that are set to a ``gguf.GGUFWriter``.
    """
    # Recommended Sampler Parameters to be written to GGUF KV Store
    # (taken from generation_config.json, then from the metadata override file)
    sampling_sequence: Optional[str] = None
    sampling_top_k: Optional[int] = None
    sampling_top_p: Optional[float] = None
    sampling_min_p: Optional[float] = None
    sampling_xtc_probability: Optional[float] = None
    sampling_xtc_threshold: Optional[float] = None
    sampling_temp: Optional[float] = None
    sampling_penalty_last_n: Optional[int] = None
    sampling_penalty_repeat: Optional[float] = None
    sampling_mirostat: Optional[int] = None
    sampling_mirostat_tau: Optional[float] = None
    sampling_mirostat_eta: Optional[float] = None
    # Authorship Metadata to be written to GGUF KV Store
    name: Optional[str] = None
    author: Optional[str] = None
    version: Optional[str] = None
    organization: Optional[str] = None
    finetune: Optional[str] = None
    basename: Optional[str] = None
    description: Optional[str] = None
    quantized_by: Optional[str] = None
    size_label: Optional[str] = None
    # Identifiers of this model itself; in the loading code shown here they are only
    # set from the metadata override file
    url: Optional[str] = None
    doi: Optional[str] = None
    uuid: Optional[str] = None
    repo_url: Optional[str] = None
    # Identifiers of the source model; the model card's url/doi/uuid/repo_url keys land here
    source_url: Optional[str] = None
    source_doi: Optional[str] = None
    source_uuid: Optional[str] = None
    source_repo_url: Optional[str] = None
    # May also hold a list (e.g. straight from the model card); a list is joined with ',' when written
    license: Optional[str] = None
    license_name: Optional[str] = None
    license_link: Optional[str] = None
    # Each entry is a dict, typically with name/organization/version/repo_url keys
    # (dict entries found in the model card are kept as-is)
    base_models: Optional[list[dict]] = None
    tags: Optional[list[str]] = None
    languages: Optional[list[str]] = None
    # Same entry shape as base_models
    datasets: Optional[list[dict]] = None
  52. @staticmethod
  53. def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
  54. # This grabs as many contextual authorship metadata as possible from the model repository
  55. # making any conversion as required to match the gguf kv store metadata format
  56. # as well as giving users the ability to override any authorship metadata that may be incorrect
  57. # Create a new Metadata instance
  58. metadata = Metadata()
  59. model_card = Metadata.load_model_card(model_path)
  60. hf_params = Metadata.load_hf_parameters(model_path)
  61. gen_config = Metadata.load_generation_config(model_path)
  62. # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
  63. # heuristics
  64. metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
  65. if gen_config:
  66. metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
  67. metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
  68. metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
  69. metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
  70. metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
  71. metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
  72. metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
  73. metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
  74. metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
  75. metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
  76. metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
  77. metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
  78. # Metadata Override File Provided
  79. # This is based on LLM_KV_NAMES mapping in llama.cpp
  80. metadata_override = Metadata.load_metadata_override(metadata_override_path)
  81. metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
  82. metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
  83. metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
  84. metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
  85. metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
  86. metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
  87. metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
  88. metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
  89. metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
  90. metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
  91. metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
  92. metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
  93. metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
  94. metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
  95. metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
  96. metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
  97. metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
  98. metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
  99. metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
  100. metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
  101. metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
  102. metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
  103. metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
  104. metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
  105. metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
  106. metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
  107. metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
  108. metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
  109. metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
  110. metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
  111. metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
  112. # Base Models is received here as an array of models
  113. metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
  114. # Datasets is received here as an array of datasets
  115. metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
  116. metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
  117. metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
  118. # Direct Metadata Override (via direct cli argument)
  119. if model_name is not None:
  120. metadata.name = model_name
  121. return metadata
  122. @staticmethod
  123. def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
  124. if metadata_override_path is None or not metadata_override_path.is_file():
  125. return {}
  126. with open(metadata_override_path, "r", encoding="utf-8") as f:
  127. return json.load(f)
  128. @staticmethod
  129. def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
  130. if model_path is None or not model_path.is_dir():
  131. return {}
  132. model_card_path = model_path / "README.md"
  133. if not model_card_path.is_file():
  134. return {}
  135. # The model card metadata is assumed to always be in YAML (frontmatter)
  136. # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
  137. yaml_content: str = ""
  138. with open(model_card_path, "r", encoding="utf-8") as f:
  139. content = f.read()
  140. lines = content.splitlines()
  141. lines_yaml = []
  142. if len(lines) == 0:
  143. # Empty file
  144. return {}
  145. if len(lines) > 0 and lines[0] != "---":
  146. # No frontmatter
  147. return {}
  148. for line in lines[1:]:
  149. if line == "---":
  150. break # End of frontmatter
  151. else:
  152. lines_yaml.append(line)
  153. yaml_content = "\n".join(lines_yaml) + "\n"
  154. # Quick hack to fix the Norway problem
  155. # https://hitchdev.com/strictyaml/why/implicit-typing-removed/
  156. yaml_content = yaml_content.replace("- no\n", "- \"no\"\n")
  157. # yaml should use 2 spaces insted of tab
  158. # this issue has came up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card
  159. # (I've also sent a pr tp fix the modelcard too)
  160. yaml_content = yaml_content.replace("\t", " ")
  161. if yaml_content:
  162. data = yaml.safe_load(yaml_content)
  163. if isinstance(data, dict):
  164. return data
  165. else:
  166. logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
  167. return {}
  168. else:
  169. return {}
  170. @staticmethod
  171. def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
  172. if model_path is None or not model_path.is_dir():
  173. return {}
  174. config_path = model_path / "config.json"
  175. if not config_path.is_file():
  176. return {}
  177. with open(config_path, "r", encoding="utf-8") as f:
  178. return json.load(f)
  179. @staticmethod
  180. def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
  181. if model_path is None or not model_path.is_dir():
  182. return {}
  183. generation_config_path = model_path / "generation_config.json"
  184. if not generation_config_path.is_file():
  185. return {}
  186. try:
  187. with open(generation_config_path, "r", encoding="utf-8") as f:
  188. return json.load(f)
  189. except (json.JSONDecodeError, IOError):
  190. # not all models have valid generation_config.json
  191. return {}
  192. @staticmethod
  193. def id_to_title(string):
  194. # Convert capitalization into title form unless acronym or version number
  195. return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
    @staticmethod
    def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
        """Split a model ID such as ``org/Name-7B-Instruct-v0.2`` into naming components.

        Each dash-separated part of the name is annotated as basename, size_label,
        finetune, version and/or type using regex heuristics. The annotated parts are
        then joined back per category.

        Args:
            model_id: ``<org>/<name>`` or just ``<name>``. An ID containing a space is
                treated as a plain human-readable name.
            total_params: Parameter count of the model being converted. 0 disables the
                size-label plausibility check. A negative value marks a LoRA adapter
                (see the ``total_params < 0`` branches).

        Returns:
            ``(model_full_name_component, org_component, basename, finetune, version, size_label)``.
            Any element may be None. basename is forced to None when size_label, finetune
            and version are all None, because the split is then too ambiguous.
        """
        # Huggingface often store model id as '<org>/<model name>'
        # so let's parse it and apply some heuristics if possible for model name components

        if model_id is None:
            # model ID missing
            return None, None, None, None, None, None

        if ' ' in model_id:
            # model ID is actually a normal human sentence
            # which means its most likely a normal model name only
            # not part of the hugging face naming standard, but whatever
            return model_id, None, None, None, None, None

        if '/' in model_id:
            # model ID (huggingface style)
            org_component, model_full_name_component = model_id.split('/', 1)
        else:
            # model ID but missing org components
            org_component, model_full_name_component = None, model_id

        # Check if we erroneously matched against './' or '../' etc...
        if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
            org_component = None

        name_parts: list[str] = model_full_name_component.split('-')

        # Remove empty parts (iterate backwards so deletions don't shift pending indices)
        for i in reversed(range(len(name_parts))):
            if len(name_parts[i]) == 0:
                del name_parts[i]

        # One annotation set per name part, kept index-aligned with name_parts
        name_types: list[
            set[Literal["basename", "size_label", "finetune", "version", "type"]]
        ] = [set() for _ in name_parts]

        # Annotate the name
        for i, part in enumerate(name_parts):
            # Version
            if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
                name_types[i].add("version")
            # Quant type (should not be there for base models, but still annotated)
            elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
                name_types[i].add("type")
                name_parts[i] = part.upper()
            # Model size (never the first part, so a name can't start with it)
            elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
                part = part.replace("_", ".")
                # Handle weird bloom-7b1 notation (e.g. '7b1' becomes '7.1b')
                if part[-1].isdecimal():
                    part = part[:-2] + "." + part[-1] + part[-2]
                # Normalize the size suffixes
                if len(part) > 1 and part[-2].isdecimal():
                    if part[-1] in "kmbt":
                        part = part[:-1] + part[-1].upper()
                if total_params != 0:
                    try:
                        # e.g. '7B' -> 7 * 1000**3; non-numeric labels ('8x7B', 'small') raise ValueError
                        label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
                        # Only use it as a size label if it's close or bigger than the model size
                        # Note that LoRA adapters don't necessarily include all layers,
                        # so this is why bigger label sizes are accepted.
                        # Do not use the size label when it's smaller than 1/8 of the model size
                        if (total_params < 0 and label_params < abs(total_params) // 8) or (
                            # Check both directions when the current model isn't a LoRA adapter
                            total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
                        ):
                            # Likely a context length
                            name_types[i].add("finetune")
                            # Lowercase the size when it's a context length
                            part = part[:-1] + part[-1].lower()
                    except ValueError:
                        # Failed to convert the size label to float, use it anyway
                        pass
                if len(name_types[i]) == 0:
                    name_types[i].add("size_label")
                name_parts[i] = part
            # Some easy to recognize finetune names
            elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
                if total_params < 0 and part.lower() == "lora":
                    # ignore redundant "lora" in the finetune part when the output is a lora adapter
                    name_types[i].add("type")
                else:
                    name_types[i].add("finetune")

        # Ignore word-based size labels when there is at least a number-based one present
        # TODO: should word-based size labels always be removed instead?
        if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
            for n, t in zip(name_parts, name_types):
                if "size_label" in t:
                    if all(c.isalpha() for c in n):
                        t.remove("size_label")

        at_start = True
        # Find the basename through the annotated name: the leading run of
        # unannotated alphabetic parts and versions; later unannotated parts become finetune
        for part, t in zip(name_parts, name_types):
            if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
                t.add("basename")
            else:
                if at_start:
                    at_start = False
                if len(t) == 0:
                    t.add("finetune")

        # Remove the basename annotation from trailing version
        for part, t in zip(reversed(name_parts), reversed(name_types)):
            if "basename" in t and len(t) > 1:
                t.remove("basename")
            else:
                break

        basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
        # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
        size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
        finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
        # TODO: should the basename version always be excluded?
        # NOTE: multiple finetune versions are joined together
        version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None

        if size_label is None and finetune is None and version is None:
            # Too ambiguous, output nothing
            basename = None

        return model_full_name_component, org_component, basename, finetune, version, size_label
  306. @staticmethod
  307. def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
  308. # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
  309. # Model Card Heuristics
  310. ########################
  311. if model_card is not None:
  312. def use_model_card_metadata(metadata_key: str, model_card_key: str):
  313. if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
  314. setattr(metadata, metadata_key, model_card.get(model_card_key))
  315. def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
  316. # Note: Will append rather than replace if already exist
  317. tags_value = model_card.get(model_card_key, None)
  318. if tags_value is None:
  319. return
  320. current_value = getattr(metadata, metadata_key, None)
  321. if current_value is None:
  322. current_value = []
  323. if isinstance(tags_value, str):
  324. current_value.append(tags_value)
  325. elif isinstance(tags_value, list):
  326. current_value.extend(tags_value)
  327. setattr(metadata, metadata_key, current_value)
  328. # LLAMA.cpp's direct internal convention
  329. # (Definitely not part of hugging face formal/informal standard)
  330. #########################################
  331. use_model_card_metadata("name", "name")
  332. use_model_card_metadata("author", "author")
  333. use_model_card_metadata("version", "version")
  334. use_model_card_metadata("organization", "organization")
  335. use_model_card_metadata("description", "description")
  336. use_model_card_metadata("finetune", "finetune")
  337. use_model_card_metadata("basename", "basename")
  338. use_model_card_metadata("size_label", "size_label")
  339. use_model_card_metadata("source_url", "url")
  340. use_model_card_metadata("source_doi", "doi")
  341. use_model_card_metadata("source_uuid", "uuid")
  342. use_model_card_metadata("source_repo_url", "repo_url")
  343. # LLAMA.cpp's huggingface style convention
  344. # (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style)
  345. ###########################################
  346. use_model_card_metadata("name", "model_name")
  347. use_model_card_metadata("author", "model_author")
  348. use_model_card_metadata("version", "model_version")
  349. use_model_card_metadata("organization", "model_organization")
  350. use_model_card_metadata("description", "model_description")
  351. use_model_card_metadata("finetune", "model_finetune")
  352. use_model_card_metadata("basename", "model_basename")
  353. use_model_card_metadata("size_label", "model_size_label")
  354. use_model_card_metadata("source_url", "model_url")
  355. use_model_card_metadata("source_doi", "model_doi")
  356. use_model_card_metadata("source_uuid", "model_uuid")
  357. use_model_card_metadata("source_repo_url", "model_repo_url")
  358. # Hugging Face Direct Convention
  359. #################################
  360. # Not part of huggingface model card standard but notice some model creator using it
  361. # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
  362. use_model_card_metadata("name", "model_name")
  363. use_model_card_metadata("author", "model_creator")
  364. use_model_card_metadata("basename", "model_type")
  365. if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
  366. # This represents the parent models that this is based on
  367. # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
  368. # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
  369. metadata_base_models = []
  370. base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))
  371. if base_model_value is not None:
  372. if isinstance(base_model_value, str):
  373. metadata_base_models.append(base_model_value)
  374. elif isinstance(base_model_value, list):
  375. metadata_base_models.extend(base_model_value)
  376. if metadata.base_models is None:
  377. metadata.base_models = []
  378. for model_id in metadata_base_models:
  379. # NOTE: model size of base model is assumed to be similar to the size of the current model
  380. base_model = {}
  381. if isinstance(model_id, str):
  382. if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
  383. base_model["repo_url"] = model_id
  384. # Check if Hugging Face ID is present in URL
  385. if "huggingface.co" in model_id:
  386. match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
  387. if match:
  388. model_id_component = match.group(1)
  389. model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
  390. # Populate model dictionary with extracted components
  391. if model_full_name_component is not None:
  392. base_model["name"] = Metadata.id_to_title(model_full_name_component)
  393. if org_component is not None:
  394. base_model["organization"] = Metadata.id_to_title(org_component)
  395. if version is not None:
  396. base_model["version"] = version
  397. else:
  398. # Likely a Hugging Face ID
  399. model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
  400. # Populate model dictionary with extracted components
  401. if model_full_name_component is not None:
  402. base_model["name"] = Metadata.id_to_title(model_full_name_component)
  403. if org_component is not None:
  404. base_model["organization"] = Metadata.id_to_title(org_component)
  405. if version is not None:
  406. base_model["version"] = version
  407. if org_component is not None and model_full_name_component is not None:
  408. base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
  409. elif isinstance(model_id, dict):
  410. base_model = model_id
  411. else:
  412. logger.error(f"base model entry '{str(model_id)}' not in a known format")
  413. metadata.base_models.append(base_model)
  414. if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
  415. # This represents the datasets that this was trained from
  416. metadata_datasets = []
  417. dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))
  418. if dataset_value is not None:
  419. if isinstance(dataset_value, str):
  420. metadata_datasets.append(dataset_value)
  421. elif isinstance(dataset_value, list):
  422. metadata_datasets.extend(dataset_value)
  423. if metadata.datasets is None:
  424. metadata.datasets = []
  425. for dataset_id in metadata_datasets:
  426. # NOTE: model size of base model is assumed to be similar to the size of the current model
  427. dataset = {}
  428. if isinstance(dataset_id, str):
  429. if dataset_id.startswith(("http://", "https://", "ssh://")):
  430. dataset["repo_url"] = dataset_id
  431. # Check if Hugging Face ID is present in URL
  432. if "huggingface.co" in dataset_id:
  433. match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
  434. if match:
  435. dataset_id_component = match.group(1)
  436. dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
  437. # Populate dataset dictionary with extracted components
  438. if dataset_name_component is not None:
  439. dataset["name"] = Metadata.id_to_title(dataset_name_component)
  440. if org_component is not None:
  441. dataset["organization"] = Metadata.id_to_title(org_component)
  442. if version is not None:
  443. dataset["version"] = version
  444. else:
  445. # Likely a Hugging Face ID
  446. dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
  447. # Populate dataset dictionary with extracted components
  448. if dataset_name_component is not None:
  449. dataset["name"] = Metadata.id_to_title(dataset_name_component)
  450. if org_component is not None:
  451. dataset["organization"] = Metadata.id_to_title(org_component)
  452. if version is not None:
  453. dataset["version"] = version
  454. if org_component is not None and dataset_name_component is not None:
  455. dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
  456. elif isinstance(dataset_id, dict):
  457. dataset = dataset_id
  458. else:
  459. logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
  460. metadata.datasets.append(dataset)
  461. use_model_card_metadata("license", "license")
  462. use_model_card_metadata("license_name", "license_name")
  463. use_model_card_metadata("license_link", "license_link")
  464. use_array_model_card_metadata("tags", "tags")
  465. use_array_model_card_metadata("tags", "pipeline_tag")
  466. use_array_model_card_metadata("languages", "languages")
  467. use_array_model_card_metadata("languages", "language")
  468. # Hugging Face Parameter Heuristics
  469. ####################################
  470. if hf_params is not None:
  471. hf_name_or_path = hf_params.get("_name_or_path")
  472. if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
  473. # Use _name_or_path only if its actually a model name and not some computer path
  474. # e.g. 'meta-llama/Llama-2-7b-hf'
  475. model_id = hf_name_or_path
  476. model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
  477. if metadata.name is None and model_full_name_component is not None:
  478. metadata.name = Metadata.id_to_title(model_full_name_component)
  479. if metadata.organization is None and org_component is not None:
  480. metadata.organization = Metadata.id_to_title(org_component)
  481. if metadata.basename is None and basename is not None:
  482. metadata.basename = basename
  483. if metadata.finetune is None and finetune is not None:
  484. metadata.finetune = finetune
  485. if metadata.version is None and version is not None:
  486. metadata.version = version
  487. if metadata.size_label is None and size_label is not None:
  488. metadata.size_label = size_label
  489. # Directory Folder Name Fallback Heuristics
  490. ############################################
  491. if model_path is not None:
  492. model_id = model_path.name
  493. model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
  494. if metadata.name is None and model_full_name_component is not None:
  495. metadata.name = Metadata.id_to_title(model_full_name_component)
  496. if metadata.organization is None and org_component is not None:
  497. metadata.organization = Metadata.id_to_title(org_component)
  498. if metadata.basename is None and basename is not None:
  499. metadata.basename = basename
  500. if metadata.finetune is None and finetune is not None:
  501. metadata.finetune = finetune
  502. if metadata.version is None and version is not None:
  503. metadata.version = version
  504. if metadata.size_label is None and size_label is not None:
  505. metadata.size_label = size_label
  506. return metadata
  def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
      """Write the populated fields of this metadata object to ``gguf_writer``.

      Optional fields whose value is ``None`` are skipped. Every other field is
      passed to the writer method named ``add_<field>``, in a fixed order:
      sampling parameters, the model name, identity/description fields,
      licensing and provenance fields, indexed base-model and dataset entries,
      and finally tags and languages.

      Args:
          gguf_writer: Destination writer; only its ``add_*`` methods are called.

      Raises:
          AssertionError: If ``self.name`` is ``None`` (only while assertions
              are enabled, i.e. not under ``python -O``).
      """
      assert self.name is not None

      def write_present(field_names):
          # Forward each non-None attribute to the identically named add_* method.
          for field_name in field_names:
              field_value = getattr(self, field_name)
              if field_value is not None:
                  getattr(gguf_writer, "add_" + field_name)(field_value)

      write_present((
          "sampling_sequence",
          "sampling_top_k",
          "sampling_top_p",
          "sampling_min_p",
          "sampling_xtc_probability",
          "sampling_xtc_threshold",
          "sampling_temp",
          "sampling_penalty_last_n",
          "sampling_penalty_repeat",
          "sampling_mirostat",
          "sampling_mirostat_tau",
          "sampling_mirostat_eta",
      ))

      # The name is mandatory (asserted above), so it is written unconditionally.
      gguf_writer.add_name(self.name)

      write_present((
          "author",
          "version",
          "organization",
          "finetune",
          "basename",
          "description",
          "quantized_by",
          "size_label",
      ))

      if self.license is not None:
          # A list of licenses is flattened into one comma-separated string.
          license_text = ",".join(self.license) if isinstance(self.license, list) else self.license
          gguf_writer.add_license(license_text)

      write_present((
          "license_name",
          "license_link",
          "url",
          "doi",
          "uuid",
          "repo_url",
          "source_url",
          "source_doi",
          "source_uuid",
          "source_repo_url",
      ))

      # Base models and datasets share one layout: a count, then per-index
      # values for whichever of these keys each entry actually contains.
      entry_keys = (
          "name", "author", "version", "organization", "description",
          "url", "doi", "uuid", "repo_url",
      )
      for prefix, entries in (("base_model", self.base_models), ("dataset", self.datasets)):
          if entries is None:
              continue
          getattr(gguf_writer, "add_" + prefix + "_count")(len(entries))
          for index, entry in enumerate(entries):
              for entry_key in entry_keys:
                  if entry_key in entry:
                      getattr(gguf_writer, "add_" + prefix + "_" + entry_key)(index, entry[entry_key])

      write_present(("tags", "languages"))