@@ -41,7 +41,7 @@ class Metadata:
     base_models: Optional[list[dict]] = None
     tags: Optional[list[str]] = None
     languages: Optional[list[str]] = None
-    datasets: Optional[list[str]] = None
+    datasets: Optional[list[dict]] = None
 
     @staticmethod
     def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
@@ -91,9 +91,11 @@ class Metadata:
         # Base Models is received here as an array of models
         metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
 
+        # Datasets is received here as an array of datasets
+        metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
+
         metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
         metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
-        metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
 
         # Direct Metadata Override (via direct cli argument)
         if model_name is not None:
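With this hunk, `general.datasets` in a metadata override file is passed through as an array of objects rather than plain strings. A minimal sketch of such an override file (JSON, as read by `Metadata.load_metadata_override`; the dataset values are invented for illustration, and the per-entry field names are the ones consumed by `set_gguf_meta_model` further below):

```json
{
    "general.datasets": [
        {
            "name": "Example Corpus",
            "organization": "Example Org",
            "version": "v1.0",
            "repo_url": "https://huggingface.co/datasets/example-org/example-corpus"
        }
    ]
}
```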
@@ -346,12 +348,12 @@ class Metadata:
         use_model_card_metadata("author", "model_creator")
         use_model_card_metadata("basename", "model_type")
 
-        if "base_model" in model_card:
+        if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
             # This represents the parent models that this is based on
             # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
             # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
             metadata_base_models = []
-            base_model_value = model_card.get("base_model", None)
+            base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))
 
             if base_model_value is not None:
                 if isinstance(base_model_value, str):
@@ -364,18 +366,106 @@ class Metadata:
 
             for model_id in metadata_base_models:
                 # NOTE: model size of base model is assumed to be similar to the size of the current model
-                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
                 base_model = {}
-                if model_full_name_component is not None:
-                    base_model["name"] = Metadata.id_to_title(model_full_name_component)
-                if org_component is not None:
-                    base_model["organization"] = Metadata.id_to_title(org_component)
-                if version is not None:
-                    base_model["version"] = version
-                if org_component is not None and model_full_name_component is not None:
-                    base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
+                if isinstance(model_id, str):
+                    if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
+                        base_model["repo_url"] = model_id
+
+                        # Check if Hugging Face ID is present in URL
+                        if "huggingface.co" in model_id:
+                            match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
+                            if match:
+                                model_id_component = match.group(1)
+                                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
+
+                                # Populate model dictionary with extracted components
+                                if model_full_name_component is not None:
+                                    base_model["name"] = Metadata.id_to_title(model_full_name_component)
+                                if org_component is not None:
+                                    base_model["organization"] = Metadata.id_to_title(org_component)
+                                if version is not None:
+                                    base_model["version"] = version
+
+                    else:
+                        # Likely a Hugging Face ID
+                        model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+
+                        # Populate model dictionary with extracted components
+                        if model_full_name_component is not None:
+                            base_model["name"] = Metadata.id_to_title(model_full_name_component)
+                        if org_component is not None:
+                            base_model["organization"] = Metadata.id_to_title(org_component)
+                        if version is not None:
+                            base_model["version"] = version
+                        if org_component is not None and model_full_name_component is not None:
+                            base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
+
+                elif isinstance(model_id, dict):
+                    base_model = model_id
+
+                else:
+                    logger.error(f"base model entry '{str(model_id)}' not in a known format")
+
                 metadata.base_models.append(base_model)
 
+        if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
+            # This represents the datasets that this was trained from
+            metadata_datasets = []
+            dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))
+
+            if dataset_value is not None:
+                if isinstance(dataset_value, str):
+                    metadata_datasets.append(dataset_value)
+                elif isinstance(dataset_value, list):
+                    metadata_datasets.extend(dataset_value)
+
+            if metadata.datasets is None:
+                metadata.datasets = []
+
+            for dataset_id in metadata_datasets:
+                # NOTE: dataset ids are parsed with the same heuristics as model ids
+                dataset = {}
+                if isinstance(dataset_id, str):
+                    if dataset_id.startswith(("http://", "https://", "ssh://")):
+                        dataset["repo_url"] = dataset_id
+
+                        # Check if Hugging Face ID is present in URL
+                        if "huggingface.co" in dataset_id:
+                            match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
+                            if match:
+                                dataset_id_component = match.group(1)
+                                dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
+
+                                # Populate dataset dictionary with extracted components
+                                if dataset_name_component is not None:
+                                    dataset["name"] = Metadata.id_to_title(dataset_name_component)
+                                if org_component is not None:
+                                    dataset["organization"] = Metadata.id_to_title(org_component)
+                                if version is not None:
+                                    dataset["version"] = version
+
+                    else:
+                        # Likely a Hugging Face ID
+                        dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
+
+                        # Populate dataset dictionary with extracted components
+                        if dataset_name_component is not None:
+                            dataset["name"] = Metadata.id_to_title(dataset_name_component)
+                        if org_component is not None:
+                            dataset["organization"] = Metadata.id_to_title(org_component)
+                        if version is not None:
+                            dataset["version"] = version
+                        if org_component is not None and dataset_name_component is not None:
+                            dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
+
+                elif isinstance(dataset_id, dict):
+                    dataset = dataset_id
+
+                else:
+                    logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
+
+                metadata.datasets.append(dataset)
+
         use_model_card_metadata("license", "license")
         use_model_card_metadata("license_name", "license_name")
         use_model_card_metadata("license_link", "license_link")
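The model-card path above accepts the same three shapes per entry as the base-model path: a bare Hugging Face ID string, a URL string, or an already-structured dict (the latter typically arriving via `dataset_sources`). A sketch of model-card YAML front matter this would parse, with invented dataset names:

```yaml
datasets:
  - example-org/example-corpus
  - https://huggingface.co/datasets/example-org/other-corpus
```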
@@ -386,9 +476,6 @@ class Metadata:
         use_array_model_card_metadata("languages", "languages")
         use_array_model_card_metadata("languages", "language")
 
-        use_array_model_card_metadata("datasets", "datasets")
-        use_array_model_card_metadata("datasets", "dataset")
-
         # Hugging Face Parameter Heuristics
         ####################################
 
@@ -493,6 +580,8 @@ class Metadata:
                     gguf_writer.add_base_model_version(key, base_model_entry["version"])
                 if "organization" in base_model_entry:
                     gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
+                if "description" in base_model_entry:
+                    gguf_writer.add_base_model_description(key, base_model_entry["description"])
                 if "url" in base_model_entry:
                     gguf_writer.add_base_model_url(key, base_model_entry["url"])
                 if "doi" in base_model_entry:
@@ -502,9 +591,29 @@ class Metadata:
                 if "repo_url" in base_model_entry:
                     gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
 
+        if self.datasets is not None:
+            gguf_writer.add_dataset_count(len(self.datasets))
+            for key, dataset_entry in enumerate(self.datasets):
+                if "name" in dataset_entry:
+                    gguf_writer.add_dataset_name(key, dataset_entry["name"])
+                if "author" in dataset_entry:
+                    gguf_writer.add_dataset_author(key, dataset_entry["author"])
+                if "version" in dataset_entry:
+                    gguf_writer.add_dataset_version(key, dataset_entry["version"])
+                if "organization" in dataset_entry:
+                    gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
+                if "description" in dataset_entry:
+                    gguf_writer.add_dataset_description(key, dataset_entry["description"])
+                if "url" in dataset_entry:
+                    gguf_writer.add_dataset_url(key, dataset_entry["url"])
+                if "doi" in dataset_entry:
+                    gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
+                if "uuid" in dataset_entry:
+                    gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
+                if "repo_url" in dataset_entry:
+                    gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
+
         if self.tags is not None:
             gguf_writer.add_tags(self.tags)
         if self.languages is not None:
             gguf_writer.add_languages(self.languages)
-        if self.datasets is not None:
-            gguf_writer.add_datasets(self.datasets)
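A quick way to sanity-check the result is to list the emitted key-value pairs with gguf-py's reader. A minimal sketch, assuming the writer methods above emit `general.dataset.count` and `general.dataset.{id}.*` keys mirroring the existing `general.base_model.*` pattern, and that `model.gguf` is a hypothetical converted output file:

```python
from gguf import GGUFReader

reader = GGUFReader("model.gguf")  # hypothetical converted model
for name in reader.fields:
    if name.startswith("general.dataset"):
        print(name)  # e.g. general.dataset.count, general.dataset.0.name
```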