diff options
Diffstat (limited to 'gguf-py/gguf/metadata.py')
-rw-r--r-- | gguf-py/gguf/metadata.py | 503 |
1 files changed, 503 insertions, 0 deletions
diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py new file mode 100644 index 00000000..15189f71 --- /dev/null +++ b/gguf-py/gguf/metadata.py @@ -0,0 +1,503 @@ +from __future__ import annotations + +import re +import json +import yaml +import logging +from pathlib import Path +from typing import Any, Literal, Optional +from dataclasses import dataclass + +from .constants import Keys + +import gguf + +logger = logging.getLogger("metadata") + + +@dataclass +class Metadata: + # Authorship Metadata to be written to GGUF KV Store + name: Optional[str] = None + author: Optional[str] = None + version: Optional[str] = None + organization: Optional[str] = None + finetune: Optional[str] = None + basename: Optional[str] = None + description: Optional[str] = None + quantized_by: Optional[str] = None + size_label: Optional[str] = None + url: Optional[str] = None + doi: Optional[str] = None + uuid: Optional[str] = None + repo_url: Optional[str] = None + source_url: Optional[str] = None + source_doi: Optional[str] = None + source_uuid: Optional[str] = None + source_repo_url: Optional[str] = None + license: Optional[str] = None + license_name: Optional[str] = None + license_link: Optional[str] = None + base_models: Optional[list[dict]] = None + tags: Optional[list[str]] = None + languages: Optional[list[str]] = None + datasets: Optional[list[str]] = None + + @staticmethod + def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata: + # This grabs as many contextual authorship metadata as possible from the model repository + # making any conversion as required to match the gguf kv store metadata format + # as well as giving users the ability to override any authorship metadata that may be incorrect + + # Create a new Metadata instance + metadata = Metadata() + + model_card = Metadata.load_model_card(model_path) + hf_params = Metadata.load_hf_parameters(model_path) + # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter + + # heuristics + metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + + # Metadata Override File Provided + # This is based on LLM_KV_NAMES mapping in llama.cpp + metadata_override = Metadata.load_metadata_override(metadata_override_path) + + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) + metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) + metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) + metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization) + + metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune) + metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename) + + metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description) + metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by) + + metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label) + metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name) + metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link) + + metadata.url = metadata_override.get(Keys.General.URL, metadata.url) + metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) + metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) + metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url) + + metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url) + metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi) + metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid) + metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url) + + # Base Models is received here as an array of models + metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) + + metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) + metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) + metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets) + + # Direct Metadata Override (via direct cli argument) + if model_name is not None: + metadata.name = model_name + + return metadata + + @staticmethod + def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]: + if metadata_override_path is None or not metadata_override_path.is_file(): + return {} + + with open(metadata_override_path, "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + model_card_path = model_path / "README.md" + + if not model_card_path.is_file(): + return {} + + # The model card metadata is assumed to always be in YAML + # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473 + with open(model_card_path, "r", encoding="utf-8") as f: + if f.readline() == "---\n": + raw = f.read().partition("---\n")[0] + data = yaml.safe_load(raw) + if isinstance(data, dict): + return data + else: + logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict") + return {} + else: + return {} + + @staticmethod + def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: + if model_path is None or not model_path.is_dir(): + return {} + + config_path = model_path / "config.json" + + if not config_path.is_file(): + return {} + + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def id_to_title(string): + # Convert capitalization into title form unless acronym or version number + return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()]) + + @staticmethod + def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]: + # Huggingface often store model id as '<org>/<model name>' + # so let's parse it and apply some heuristics if possible for model name components + + if model_id is None: + # model ID missing + return None, None, None, None, None, None + + if ' ' in model_id: + # model ID is actually a normal human sentence + # which means its most likely a normal model name only + # not part of the hugging face naming standard, but whatever + return model_id, None, None, None, None, None + + if '/' in model_id: + # model ID (huggingface style) + org_component, model_full_name_component = model_id.split('/', 1) + else: + # model ID but missing org components + org_component, model_full_name_component = None, model_id + + # Check if we erroneously matched against './' or '../' etc... + if org_component is not None and org_component[0] == '.': + org_component = None + + name_parts: list[str] = model_full_name_component.split('-') + + # Remove empty parts + for i in reversed(range(len(name_parts))): + if len(name_parts[i]) == 0: + del name_parts[i] + + name_types: list[ + set[Literal["basename", "size_label", "finetune", "version", "type"]] + ] = [set() for _ in name_parts] + + # Annotate the name + for i, part in enumerate(name_parts): + # Version + if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE): + name_types[i].add("version") + # Quant type (should not be there for base models, but still annotated) + elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE): + name_types[i].add("type") + name_parts[i] = part.upper() + # Model size + elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE): + part = part.replace("_", ".") + # Handle weird bloom-7b1 notation + if part[-1].isdecimal(): + part = part[:-2] + "." + part[-1] + part[-2] + # Normalize the size suffixes + if len(part) > 1 and part[-2].isdecimal(): + if part[-1] in "kmbt": + part = part[:-1] + part[-1].upper() + if total_params != 0: + try: + label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1])) + # Only use it as a size label if it's close or bigger than the model size + # Note that LoRA adapters don't necessarily include all layers, + # so this is why bigger label sizes are accepted. + # Do not use the size label when it's smaller than 1/8 of the model size + if (total_params < 0 and label_params < abs(total_params) // 8) or ( + # Check both directions when the current model isn't a LoRA adapter + total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8 + ): + # Likely a context length + name_types[i].add("finetune") + # Lowercase the size when it's a context length + part = part[:-1] + part[-1].lower() + except ValueError: + # Failed to convert the size label to float, use it anyway + pass + if len(name_types[i]) == 0: + name_types[i].add("size_label") + name_parts[i] = part + # Some easy to recognize finetune names + elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE): + if total_params < 0 and part.lower() == "lora": + # ignore redundant "lora" in the finetune part when the output is a lora adapter + name_types[i].add("type") + else: + name_types[i].add("finetune") + + # Ignore word-based size labels when there is at least a number-based one present + # TODO: should word-based size labels always be removed instead? + if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n): + for n, t in zip(name_parts, name_types): + if "size_label" in t: + if all(c.isalpha() for c in n): + t.remove("size_label") + + at_start = True + # Find the basename through the annotated name + for part, t in zip(name_parts, name_types): + if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t): + t.add("basename") + else: + if at_start: + at_start = False + if len(t) == 0: + t.add("finetune") + + # Remove the basename annotation from trailing version + for part, t in zip(reversed(name_parts), reversed(name_types)): + if "basename" in t and len(t) > 1: + t.remove("basename") + else: + break + + basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None + # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys) + size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None + finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None + # TODO: should the basename version always be excluded? + # NOTE: multiple finetune versions are joined together + version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None + + if size_label is None and finetune is None and version is None: + # Too ambiguous, output nothing + basename = None + + return model_full_name_component, org_component, basename, finetune, version, size_label + + @staticmethod + def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata: + # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 + + # Model Card Heuristics + ######################## + if model_card is not None: + + if "model_name" in model_card and metadata.name is None: + # Not part of huggingface model card standard but notice some model creator using it + # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' + metadata.name = model_card.get("model_name") + + if "model_creator" in model_card and metadata.author is None: + # Not part of huggingface model card standard but notice some model creator using it + # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' + metadata.author = model_card.get("model_creator") + + if "model_type" in model_card and metadata.basename is None: + # Not part of huggingface model card standard but notice some model creator using it + # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' + metadata.basename = model_card.get("model_type") + + if "base_model" in model_card: + # This represents the parent models that this is based on + # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) + # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md + metadata_base_models = [] + base_model_value = model_card.get("base_model", None) + + if base_model_value is not None: + if isinstance(base_model_value, str): + metadata_base_models.append(base_model_value) + elif isinstance(base_model_value, list): + metadata_base_models.extend(base_model_value) + + if metadata.base_models is None: + metadata.base_models = [] + + for model_id in metadata_base_models: + # NOTE: model size of base model is assumed to be similar to the size of the current model + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + base_model = {} + if model_full_name_component is not None: + base_model["name"] = Metadata.id_to_title(model_full_name_component) + if org_component is not None: + base_model["organization"] = Metadata.id_to_title(org_component) + if version is not None: + base_model["version"] = version + if org_component is not None and model_full_name_component is not None: + base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" + metadata.base_models.append(base_model) + + if "license" in model_card and metadata.license is None: + metadata.license = model_card.get("license") + + if "license_name" in model_card and metadata.license_name is None: + metadata.license_name = model_card.get("license_name") + + if "license_link" in model_card and metadata.license_link is None: + metadata.license_link = model_card.get("license_link") + + tags_value = model_card.get("tags", None) + if tags_value is not None: + + if metadata.tags is None: + metadata.tags = [] + + if isinstance(tags_value, str): + metadata.tags.append(tags_value) + elif isinstance(tags_value, list): + metadata.tags.extend(tags_value) + + pipeline_tags_value = model_card.get("pipeline_tag", None) + if pipeline_tags_value is not None: + + if metadata.tags is None: + metadata.tags = [] + + if isinstance(pipeline_tags_value, str): + metadata.tags.append(pipeline_tags_value) + elif isinstance(pipeline_tags_value, list): + metadata.tags.extend(pipeline_tags_value) + + language_value = model_card.get("languages", model_card.get("language", None)) + if language_value is not None: + + if metadata.languages is None: + metadata.languages = [] + + if isinstance(language_value, str): + metadata.languages.append(language_value) + elif isinstance(language_value, list): + metadata.languages.extend(language_value) + + dataset_value = model_card.get("datasets", model_card.get("dataset", None)) + if dataset_value is not None: + + if metadata.datasets is None: + metadata.datasets = [] + + if isinstance(dataset_value, str): + metadata.datasets.append(dataset_value) + elif isinstance(dataset_value, list): + metadata.datasets.extend(dataset_value) + + # Hugging Face Parameter Heuristics + #################################### + + if hf_params is not None: + + hf_name_or_path = hf_params.get("_name_or_path") + if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1: + # Use _name_or_path only if its actually a model name and not some computer path + # e.g. 'meta-llama/Llama-2-7b-hf' + model_id = hf_name_or_path + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + if metadata.name is None and model_full_name_component is not None: + metadata.name = Metadata.id_to_title(model_full_name_component) + if metadata.organization is None and org_component is not None: + metadata.organization = Metadata.id_to_title(org_component) + if metadata.basename is None and basename is not None: + metadata.basename = basename + if metadata.finetune is None and finetune is not None: + metadata.finetune = finetune + if metadata.version is None and version is not None: + metadata.version = version + if metadata.size_label is None and size_label is not None: + metadata.size_label = size_label + + # Directory Folder Name Fallback Heuristics + ############################################ + if model_path is not None: + model_id = model_path.name + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + if metadata.name is None and model_full_name_component is not None: + metadata.name = Metadata.id_to_title(model_full_name_component) + if metadata.organization is None and org_component is not None: + metadata.organization = Metadata.id_to_title(org_component) + if metadata.basename is None and basename is not None: + metadata.basename = basename + if metadata.finetune is None and finetune is not None: + metadata.finetune = finetune + if metadata.version is None and version is not None: + metadata.version = version + if metadata.size_label is None and size_label is not None: + metadata.size_label = size_label + + return metadata + + def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): + assert self.name is not None + gguf_writer.add_name(self.name) + + if self.author is not None: + gguf_writer.add_author(self.author) + if self.version is not None: + gguf_writer.add_version(self.version) + if self.organization is not None: + gguf_writer.add_organization(self.organization) + + if self.finetune is not None: + gguf_writer.add_finetune(self.finetune) + if self.basename is not None: + gguf_writer.add_basename(self.basename) + + if self.description is not None: + gguf_writer.add_description(self.description) + if self.quantized_by is not None: + gguf_writer.add_quantized_by(self.quantized_by) + + if self.size_label is not None: + gguf_writer.add_size_label(self.size_label) + + if self.license is not None: + gguf_writer.add_license(self.license) + if self.license_name is not None: + gguf_writer.add_license_name(self.license_name) + if self.license_link is not None: + gguf_writer.add_license_link(self.license_link) + + if self.url is not None: + gguf_writer.add_url(self.url) + if self.doi is not None: + gguf_writer.add_doi(self.doi) + if self.uuid is not None: + gguf_writer.add_uuid(self.uuid) + if self.repo_url is not None: + gguf_writer.add_repo_url(self.repo_url) + + if self.source_url is not None: + gguf_writer.add_source_url(self.source_url) + if self.source_doi is not None: + gguf_writer.add_source_doi(self.source_doi) + if self.source_uuid is not None: + gguf_writer.add_source_uuid(self.source_uuid) + if self.source_repo_url is not None: + gguf_writer.add_source_repo_url(self.source_repo_url) + + if self.base_models is not None: + gguf_writer.add_base_model_count(len(self.base_models)) + for key, base_model_entry in enumerate(self.base_models): + if "name" in base_model_entry: + gguf_writer.add_base_model_name(key, base_model_entry["name"]) + if "author" in base_model_entry: + gguf_writer.add_base_model_author(key, base_model_entry["author"]) + if "version" in base_model_entry: + gguf_writer.add_base_model_version(key, base_model_entry["version"]) + if "organization" in base_model_entry: + gguf_writer.add_base_model_organization(key, base_model_entry["organization"]) + if "url" in base_model_entry: + gguf_writer.add_base_model_url(key, base_model_entry["url"]) + if "doi" in base_model_entry: + gguf_writer.add_base_model_doi(key, base_model_entry["doi"]) + if "uuid" in base_model_entry: + gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"]) + if "repo_url" in base_model_entry: + gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"]) + + if self.tags is not None: + gguf_writer.add_tags(self.tags) + if self.languages is not None: + gguf_writer.add_languages(self.languages) + if self.datasets is not None: + gguf_writer.add_datasets(self.datasets) |