Diffstat (limited to 'gguf-py')
-rw-r--r--  gguf-py/gguf/gguf.py    | 571
-rw-r--r--  gguf-py/gguf/py.typed   |   0
-rw-r--r--  gguf-py/pyproject.toml  |   1
3 files changed, 341 insertions, 231 deletions
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index 838a2c0f..de3edbc9 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -4,9 +4,13 @@ import sys
 import struct
 import tempfile
 import numpy as np
+import json
+import os
+from pathlib import Path
 from enum import IntEnum, auto
-from typing import Any, IO, List, Optional
+from io import BufferedWriter
+from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union

 #
 # constants
@@ -71,35 +75,35 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"


 class MODEL_ARCH(IntEnum):
-    LLAMA = auto()
-    FALCON = auto()
-    GPT2 = auto()
-    GPTJ = auto()
-    GPTNEOX = auto()
-    MPT = auto()
+    LLAMA  : int = auto()
+    FALCON : int = auto()
+    GPT2   : int = auto()
+    GPTJ   : int = auto()
+    GPTNEOX: int = auto()
+    MPT    : int = auto()


 class MODEL_TENSOR(IntEnum):
-    TOKEN_EMBD = auto()
-    POS_EMBD = auto()
-    OUTPUT = auto()
-    OUTPUT_NORM = auto()
-    ROPE_FREQS = auto()
-    ATTN_Q = auto()
-    ATTN_K = auto()
-    ATTN_V = auto()
-    ATTN_QKV = auto()
-    ATTN_OUT = auto()
-    ATTN_NORM = auto()
-    ATTN_NORM_2 = auto()
-    ATTN_ROT_EMBD = auto()
-    FFN_GATE = auto()
-    FFN_DOWN = auto()
-    FFN_UP = auto()
-    FFN_NORM = auto()
-
-
-MODEL_ARCH_NAMES = {
+    TOKEN_EMBD   : int = auto()
+    POS_EMBD     : int = auto()
+    OUTPUT       : int = auto()
+    OUTPUT_NORM  : int = auto()
+    ROPE_FREQS   : int = auto()
+    ATTN_Q       : int = auto()
+    ATTN_K       : int = auto()
+    ATTN_V       : int = auto()
+    ATTN_QKV     : int = auto()
+    ATTN_OUT     : int = auto()
+    ATTN_NORM    : int = auto()
+    ATTN_NORM_2  : int = auto()
+    ATTN_ROT_EMBD: int = auto()
+    FFN_GATE     : int = auto()
+    FFN_DOWN     : int = auto()
+    FFN_UP       : int = auto()
+    FFN_NORM     : int = auto()
+
+
+MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA: "llama",
     MODEL_ARCH.FALCON: "falcon",
     MODEL_ARCH.GPT2: "gpt2",
@@ -108,7 +112,7 @@ MODEL_ARCH_NAMES = {
     MODEL_ARCH.MPT: "mpt",
 }

-MODEL_TENSOR_NAMES = {
+MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
     MODEL_ARCH.LLAMA: {
         MODEL_TENSOR.TOKEN_EMBD: "token_embd",
         MODEL_TENSOR.OUTPUT_NORM: "output_norm",
@@ -154,7 +158,7 @@ MODEL_TENSOR_NAMES = {
 }

 # tensors that will not be serialized
-MODEL_TENSOR_SKIP = {
+MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
@@ -162,167 +166,198 @@ MODEL_TENSOR_SKIP = {
 }


-# TODO: the following helper functions should be removed
-# instead, get_tensor_name_map should return tuples of (name, MODEL_TENSOR)
-# however, my Python is very bad, and I couldn't figure out how to do this, hence these functions
-# REMOVE
-def should_skip_tensor_TMP(arch: MODEL_ARCH, n_blocks: int, name: str) -> bool:
-    for skip in MODEL_TENSOR_SKIP.get(arch, []):
-        for i in range(n_blocks):
-            if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i):
-                return True
-
-    return False
-
-
-def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict:
-    tensor_map = {}
-
-    # Token embeddings
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.TOKEN_EMBD, None)
-
-    tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox
-    tensor_map["transformer.wte"] = mapped_to # gpt2 mpt
-    tensor_map["transformer.word_embeddings"] = mapped_to # falcon
-    tensor_map["model.embed_tokens"] = mapped_to # llama-hf
-    tensor_map["tok_embeddings"] = mapped_to # llama-pth
-
-    # Position embeddings
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.POS_EMBD, None)
-
-    tensor_map["transformer.wpe"] = mapped_to # gpt2
-
-    # Output
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT, None)
-
-    tensor_map["embed_out"] = mapped_to # gptneox
-    tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
-    tensor_map["output"] = mapped_to # llama-pth
-
-    # Output norm
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT_NORM, None)
-
-    tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox
-    tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon
-    tensor_map["transformer.norm_f"] = mapped_to # mpt
-    tensor_map["model.norm"] = mapped_to # llama-hf
-    tensor_map["norm"] = mapped_to # llama-pth
-
-    # Rope frequencies
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ROPE_FREQS, None)
-
-    tensor_map["rope.freqs"] = mapped_to # llama-pth
-
-    # Attention and feed-forward blocks
-    for i in range(0, n_blocks):
+class TensorNameMap:
+    mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
+        # Token embeddings
+        MODEL_TENSOR.TOKEN_EMBD: (
+            "gpt_neox.embed_in", # gptneox
+            "transformer.wte", # gpt2 mpt
+            "transformer.word_embeddings", # falcon
+            "model.embed_tokens", # llama-hf
+            "tok_embeddings", # llama-pth
+        ),
+
+        # Position embeddings
+        MODEL_TENSOR.POS_EMBD: (
+            "transformer.wpe", # gpt2
+        ),
+
+        # Output
+        MODEL_TENSOR.OUTPUT: (
+            "embed_out", # gptneox
+            "lm_head", # gpt2 mpt falcon llama-hf
+            "output", # llama-pth
+        ),
+
+        # Output norm
+        MODEL_TENSOR.OUTPUT_NORM: (
+            "gpt_neox.final_layer_norm", # gptneox
+            "transformer.ln_f", # gpt2 falcon
+            "model.norm", # llama-hf
+            "norm", # llama-pth
+        ),
+
+        # Rope frequencies
+        MODEL_TENSOR.ROPE_FREQS: (
+            "rope.freqs", # llama-pth
+        ),
+    }
+
+    block_mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
         # Attention norm
-        # TODO: is there are simpler way to write these 2 lines in Python?
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b
-        tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b
-        tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth
+        MODEL_TENSOR.ATTN_NORM: (
+            "gpt_neox.layers.{bid}.input_layernorm", # gptneox
+            "transformer.h.{bid}.ln_1", # gpt2
+            "transformer.blocks.{bid}.norm_1", # mpt
+            "transformer.h.{bid}.input_layernorm", # falcon7b
+            "transformer.h.{bid}.ln_mlp", # falcon40b
+            "model.layers.{bid}.input_layernorm", # llama-hf
+            "layers.{bid}.attention_norm", # llama-pth
+        ),

         # Attention norm 2
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM_2, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b
+        MODEL_TENSOR.ATTN_NORM_2: (
+            "transformer.h.{bid}.ln_attn", # falcon40b
+        ),

         # Attention query-key-value
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_QKV, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon
+        MODEL_TENSOR.ATTN_QKV: (
+            "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
+            "transformer.h.{bid}.attn.c_attn", # gpt2
+            "transformer.blocks.{bid}.attn.Wqkv", # mpt
+            "transformer.h.{bid}.self_attention.query_key_value", # falcon
+        ),

         # Attention query
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_Q, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth
+        MODEL_TENSOR.ATTN_Q: (
+            "model.layers.{bid}.self_attn.q_proj", # llama-hf
+            "layers.{bid}.attention.wq", # llama-pth
+        ),

         # Attention key
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_K, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth
+        MODEL_TENSOR.ATTN_K: (
+            "model.layers.{bid}.self_attn.k_proj", # llama-hf
+            "layers.{bid}.attention.wk", # llama-pth
+        ),

         # Attention value
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_V, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth
+        MODEL_TENSOR.ATTN_V: (
+            "model.layers.{bid}.self_attn.v_proj", # llama-hf
+            "layers.{bid}.attention.wv", # llama-pth
+        ),

         # Attention output
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_OUT, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon
-        tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth
+        MODEL_TENSOR.ATTN_OUT: (
+            "gpt_neox.layers.{bid}.attention.dense", # gptneox
+            "transformer.h.{bid}.attn.c_proj", # gpt2
+            "transformer.blocks.{bid}.attn.out_proj", # mpt
+            "transformer.h.{bid}.self_attention.dense", # falcon
+            "model.layers.{bid}.self_attn.o_proj", # llama-hf
+            "layers.{bid}.attention.wo", # llama-pth
+        ),

         # Rotary embeddings
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_ROT_EMBD, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.rotary_emb.inv_freq"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.inner_attention.rope.freqs"] = mapped_to # llama-pth
+        MODEL_TENSOR.ATTN_ROT_EMBD: (
+            "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
+            "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
+        ),

         # Feed-forward norm
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_NORM, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt
-        tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
+        MODEL_TENSOR.FFN_NORM: (
+            "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
+            "transformer.h.{bid}.ln_2", # gpt2
+            "transformer.blocks.{bid}.norm_2", # mpt
+            "model.layers.{bid}.post_attention_layernorm", # llama-hf
+            "layers.{bid}.ffn_norm", # llama-pth
+        ),

         # Feed-forward up
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon
-        tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth
+        MODEL_TENSOR.FFN_UP: (
+            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
+            "transformer.h.{bid}.mlp.c_fc", # gpt2
+            "transformer.blocks.{bid}.ffn.up_proj", # mpt
+            "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
+            "model.layers.{bid}.mlp.up_proj", # llama-hf
+            "layers.{bid}.feed_forward.w3", # llama-pth
+        ),

         # Feed-forward gate
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_GATE, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth
+        MODEL_TENSOR.FFN_GATE: (
+            "model.layers.{bid}.mlp.gate_proj", # llama-hf
+            "layers.{bid}.feed_forward.w1", # llama-pth
+        ),

         # Feed-forward down
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_DOWN, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon
-        tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth
-
-    return tensor_map
-
+        MODEL_TENSOR.FFN_DOWN: (
+            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
+            "transformer.h.{bid}.mlp.c_proj", # gpt2
+            "transformer.blocks.{bid}.ffn.down_proj", # mpt
+            "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
+            "model.layers.{bid}.mlp.down_proj", # llama-hf
+            "layers.{bid}.feed_forward.w2", # llama-pth
+        ),
+    }
+
+    mapping: Dict[str, Tuple[MODEL_TENSOR, str]]
+
+    tensor_names: Dict[MODEL_TENSOR, str]
+
+    def __init__(self, arch: MODEL_ARCH, n_blocks: int):
+        mapping = self.mapping = {}
+        tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
+        for tensor, keys in self.mappings_cfg.items():
+            tensor_name = tensor_names.get(tensor)
+            if tensor_name is None:
+                continue
+            for key in keys:
+                mapping[key] = (tensor, tensor_name)
+        for bid in range(n_blocks):
+            for tensor, keys in self.block_mappings_cfg.items():
+                tensor_name = tensor_names.get(tensor)
+                if tensor_name is None:
+                    continue
+                tensor_name = tensor_name.format(bid = bid)
+                for key in keys:
+                    key = key.format(bid = bid)
+                    mapping[key] = (tensor, tensor_name)
+
+    def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
+        result = self.mapping.get(key)
+        if result is not None:
+            return result
+        for suffix in try_suffixes:
+            if key.endswith(suffix):
+                result = self.mapping.get(key[:-len(suffix)])
+                if result is not None:
+                    return (result[0], result[1] + suffix)
+        return None
+
+    def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
+        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
+        if result is None:
+            return None
+        return result[1]
+
+    def get_type(self, key: str, try_suffixes: Sequence[str]) -> Optional[MODEL_TENSOR]:
+        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
+        if result is None:
+            return None
+        return result[0]
+
+    def __getitem__(self, key: str) -> str:
+        try:
+            return self.mapping[key][1]
+        except KeyError:
+            raise KeyError(key)
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.mapping
+
+    def __repr__(self) -> str:
+        return repr(self.mapping)
+
+def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
+    return TensorNameMap(arch, n_blocks)


 class TokenType(IntEnum):
     NORMAL = 1
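Note (not part of the commit): the TensorNameMap class above replaces the flat per-block dict that get_tensor_name_map used to build, and its suffix-aware lookup lets callers map full HF-style tensor names directly. A minimal usage sketch, assuming the gguf-py package is importable as `gguf`; the tensor name and block count below are illustrative, not from a real checkpoint:

import gguf  # assumption: `from .gguf import *` re-exports these names

# A converter would pass the model's real block count instead of 32.
tmap = gguf.TensorNameMap(gguf.MODEL_ARCH.LLAMA, 32)

# get_name() strips a matching suffix, looks up the base name, then re-appends the suffix.
name = tmap.get_name("model.layers.0.self_attn.q_proj.weight", try_suffixes = (".weight", ".bias"))
if name is None:
    print("unmapped tensor name, skipping")
else:
    print("GGUF tensor name:", name)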
@@ -388,15 +423,21 @@ class GGUFValueType(IntEnum):


 class GGUFWriter:
-    def __init__(self, path: str, arch: str, use_temp_file = True):
+    fout: BufferedWriter
+    arch: str
+    offset_tensor = 0
+    data_alignment = GGUF_DEFAULT_ALIGNMENT
+    kv_data = b""
+    kv_data_count = 0
+    ti_data = b""
+    ti_data_count = 0
+    use_temp_file: bool
+    temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
+    tensors: List[Tuple[np.ndarray[Any, Any], int]]
+
+    def __init__(self, path: Union[os.PathLike[str], str], arch: str, use_temp_file = True):
         self.fout = open(path, "wb")
         self.arch = arch
-        self.offset_tensor = 0
-        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
-        self.kv_data = b""
-        self.kv_data_count = 0
-        self.ti_data = b""
-        self.ti_data_count = 0
         self.add_architecture()
         self.use_temp_file = use_temp_file
         self.tensors = []
@@ -470,14 +511,27 @@ class GGUFWriter:
         self.add_key(key)
         self.add_val(val, GGUFValueType.STRING)

-    def add_array(self, key: str, val: list):
-        if not isinstance(val, list):
-            raise ValueError("Value must be a list for array type")
+    def add_array(self, key: str, val: Sequence[Any]):
+        if not isinstance(val, Sequence):
+            raise ValueError("Value must be a sequence for array type")

         self.add_key(key)
         self.add_val(val, GGUFValueType.ARRAY)

-    def add_val(self: str, val: Any, vtype: GGUFValueType = None, add_vtype: bool = True):
+    _simple_value_packing = {
+        GGUFValueType.UINT8: "<B",
+        GGUFValueType.INT8: "<b",
+        GGUFValueType.UINT16: "<H",
+        GGUFValueType.INT16: "<h",
+        GGUFValueType.UINT32: "<I",
+        GGUFValueType.INT32: "<i",
+        GGUFValueType.FLOAT32: "<f",
+        GGUFValueType.UINT64: "<Q",
+        GGUFValueType.INT64: "<q",
+        GGUFValueType.FLOAT64: "<d",
+        GGUFValueType.BOOL: "?" ,
+    }
+    def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
         if vtype is None:
             vtype = GGUFValueType.get_type(val)
@@ -485,47 +539,29 @@ class GGUFWriter:
         self.kv_data += struct.pack("<I", vtype)
         self.kv_data_count += 1

-        if vtype == GGUFValueType.UINT8:
-            self.kv_data += struct.pack("<B", val)
-        elif vtype == GGUFValueType.INT8:
-            self.kv_data += struct.pack("<b", val)
-        elif vtype == GGUFValueType.UINT16:
-            self.kv_data += struct.pack("<H", val)
-        elif vtype == GGUFValueType.INT16:
-            self.kv_data += struct.pack("<h", val)
-        elif vtype == GGUFValueType.UINT32:
-            self.kv_data += struct.pack("<I", val)
-        elif vtype == GGUFValueType.INT32:
-            self.kv_data += struct.pack("<i", val)
-        elif vtype == GGUFValueType.FLOAT32:
-            self.kv_data += struct.pack("<f", val)
-        elif vtype == GGUFValueType.UINT64:
-            self.kv_data += struct.pack("<Q", val)
-        elif vtype == GGUFValueType.INT64:
-            self.kv_data += struct.pack("<q", val)
-        elif vtype == GGUFValueType.FLOAT64:
-            self.kv_data += struct.pack("<d", val)
-        elif vtype == GGUFValueType.BOOL:
-            self.kv_data += struct.pack("?", val)
+        pack_fmt = self._simple_value_packing.get(vtype)
+        if pack_fmt is not None:
+            self.kv_data += struct.pack(pack_fmt, val)
         elif vtype == GGUFValueType.STRING:
             encoded_val = val.encode("utf8") if isinstance(val, str) else val
             self.kv_data += struct.pack("<Q", len(encoded_val))
             self.kv_data += encoded_val
-        elif vtype == GGUFValueType.ARRAY:
-            ltype = set([GGUFValueType.get_type(item) for item in val])
-            assert len(ltype) == 1, "All items in a GGUF array should be of the same type"
-            self.kv_data += struct.pack("<I", list(ltype)[0])
+        elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
+            ltype = GGUFValueType.get_type(val[0])
+            if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
+                raise ValueError("All items in a GGUF array should be of the same type")
+            self.kv_data += struct.pack("<I", ltype)
             self.kv_data += struct.pack("<Q", len(val))
             for item in val:
                 self.add_val(item, add_vtype=False)
         else:
-            raise ValueError("Invalid GGUF metadata value type")
+            raise ValueError("Invalid GGUF metadata value type or value")

     @staticmethod
     def ggml_pad(x: int, n: int) -> int:
         return ((x + n - 1) // n) * n

-    def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
+    def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: Union[np.dtype[np.float16], np.dtype[np.float32]], tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
         assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"

         encoded_name = name.encode("utf8")
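Note (not part of the commit): the refactored add_val() above replaces the long if/elif chain with a lookup into _simple_value_packing, and the "<" prefix in those struct formats indicates little-endian serialization. A standalone sketch of what two of the table's formats produce, shown here only as an assumption-checked illustration:

import struct

print(struct.pack("<I", 4096).hex())  # UINT32 4096 -> '00100000', 4 little-endian bytes
print(struct.pack("<Q", 3).hex())     # UINT64, the length prefix used for strings and arrays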
@@ -544,16 +580,18 @@ class GGUFWriter:
         self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
         self.ti_data_count += 1

-    def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
-        if self.use_temp_file and not hasattr(self, "temp_file"):
-            self.temp_file = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
-            self.temp_file.seek(0)
+    def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Optional[Sequence[int]] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
+        if self.use_temp_file and self.temp_file is None:
+            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
+            fp.seek(0)
+            self.temp_file = fp

-        self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
+        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
+        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)

         pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes

-        if not self.use_temp_file:
+        if self.temp_file is None:
             self.tensors.append((tensor, pad))
             return
@@ -562,25 +600,22 @@ class GGUFWriter:
         if pad != 0:
             self.temp_file.write(bytes([0] * pad))

-    def write_tensor_data(self, tensor: np.ndarray):
-        pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
+    def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
+        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
         if pad != 0:
-            self.fout.write(bytes([0] * pad))
+            fp.write(bytes([0] * pad))

+    def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
+        self.write_padding(self.fout, self.fout.tell())
         tensor.tofile(self.fout)
-
-        pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
-        if pad != 0:
-            self.fout.write(bytes([0] * pad))
+        self.write_padding(self.fout, tensor.nbytes)

     def write_tensors_to_file(self):
         self.write_ti_data_to_file()

-        pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
-        if pad != 0:
-            self.fout.write(bytes([0] * pad))
+        self.write_padding(self.fout, self.fout.tell())

-        if not self.use_temp_file:
+        if self.temp_file is None:
             for (currtensor, currpad) in self.tensors:
                 currtensor.tofile(self.fout)
                 if currpad != 0:
@@ -654,10 +689,6 @@ class GGUFWriter:
         self.add_bool(
             KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

-    def add_tensor_data_layout(self, layout: str):
-        self.add_string(
-            KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
-
     def add_head_count(self, count: int):
         self.add_uint32(
             KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
@@ -695,16 +726,16 @@ class GGUFWriter:
     def add_tokenizer_model(self, model: str):
         self.add_string(KEY_TOKENIZER_MODEL, model)

-    def add_token_list(self, tokens: List):
+    def add_token_list(self, tokens: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
         self.add_array(KEY_TOKENIZER_LIST, tokens)

-    def add_token_merges(self, merges: List):
+    def add_token_merges(self, merges: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
         self.add_array(KEY_TOKENIZER_MERGES, merges)

-    def add_token_types(self, types: List[int]):
+    def add_token_types(self, types: Union[Sequence[TokenType], Sequence[int]]):
         self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)

-    def add_token_scores(self, scores: List[float]):
+    def add_token_scores(self, scores: Sequence[float]):
         self.add_array(KEY_TOKENIZER_SCORES, scores)

     def add_bos_token_id(self, id: int):
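Note (not part of the commit): the write_padding() helper added above centralizes the alignment rounding that was previously repeated in write_tensor_data and write_tensors_to_file. A standalone sketch of the same rounding formula; treating GGUF_DEFAULT_ALIGNMENT as 32 here is an assumption taken from elsewhere in this module:

def ggml_pad(x: int, n: int) -> int:
    # round x up to the next multiple of n, as the static method above does
    return ((x + n - 1) // n) * n

align = 32  # assumed value of GGUF_DEFAULT_ALIGNMENT
for offset in (0, 1, 31, 32, 33):
    print(offset, "->", ggml_pad(offset, align) - offset, "padding bytes")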
@@ -723,6 +754,84 @@ class GGUFWriter:
         self.add_uint32(KEY_TOKENIZER_PAD_ID, id)


+class SpecialVocab:
+    load_merges: bool = False
+    merges: List[str] = []
+    special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
+    special_token_ids: Dict[str, int] = {}
+
+    def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
+        self.special_token_ids = {}
+        self.load_merges = load_merges
+        if special_token_types is not None:
+            self.special_token_types = special_token_types
+        self.load(path)
+
+    def load(self, path: Path):
+        if not self.try_load_from_tokenizer_json(path):
+            self.try_load_from_config_json(path)
+
+    def try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer_file = path / 'tokenizer.json'
+        if not tokenizer_file.is_file():
+            return False
+        with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
+            tokenizer = json.load(f)
+        if self.load_merges:
+            merges = tokenizer.get('model', {}).get('merges')
+            if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
+                self.merges = merges
+        tokenizer_config_file = path / 'tokenizer_config.json'
+        added_tokens = tokenizer.get('added_tokens')
+        if added_tokens is None or not tokenizer_config_file.is_file():
+            return True
+        with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
+            tokenizer_config = json.load(f)
+        for typ in self.special_token_types:
+            entry = tokenizer_config.get(f'{typ}_token')
+            if isinstance(entry, str):
+                tc_content = entry
+            elif isinstance(entry, dict):
+                entry_content = entry.get('content')
+                if not isinstance(entry_content, str):
+                    continue
+                tc_content = entry_content
+            else:
+                continue
+            for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
+                if isinstance(maybe_token_id, int):
+                    self.special_token_ids[typ] = maybe_token_id
+                break
+        return True
+
+    def try_load_from_config_json(self, path: Path) -> bool:
+        config_file = path / 'config.json'
+        if not config_file.is_file():
+            return False
+        with open(config_file, 'r', encoding = 'utf-8') as f:
+            config = json.load(f)
+        for typ in self.special_token_types:
+            maybe_token_id = config.get(f'{typ}_token_id')
+            if isinstance(maybe_token_id, int):
+                self.special_token_ids[typ] = maybe_token_id
+        return True
+
+    def add_to_gguf(self, gw: GGUFWriter):
+        if len(self.merges) > 0:
+            print(f'gguf: Adding {len(self.merges)} merge(s).')
+            gw.add_token_merges(self.merges)
+        for typ, tokid in self.special_token_ids.items():
+            handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
+            if handler is None:
+                print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
+                continue
+            print(f'gguf: Setting special token type {typ} to {tokid}')
+            handler(tokid)
+
+    def __repr__(self):
+        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
+
+
 # Example usage:
 if __name__ == "__main__":
     # Example usage with a file
diff --git a/gguf-py/gguf/py.typed b/gguf-py/gguf/py.typed
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/gguf-py/gguf/py.typed
diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml
index cc70e28b..c66b069f 100644
--- a/gguf-py/pyproject.toml
+++ b/gguf-py/pyproject.toml
@@ -5,6 +5,7 @@ description = "Write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [
     {include = "gguf"},
+    {include = "gguf/py.typed"},
 ]
 readme = "README.md"
 homepage = "https://ggml.ai"
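Note (not part of the commit): the new SpecialVocab class reads merges and special token ids from tokenizer.json / tokenizer_config.json / config.json and forwards them to a GGUFWriter. A minimal wiring sketch, assuming the gguf-py package is importable as `gguf`; the paths and output filename are hypothetical, and a full conversion would also write the header, KV data and tensors before closing the file:

from pathlib import Path

import gguf  # assumption: GGUFWriter and SpecialVocab are re-exported at package level

writer = gguf.GGUFWriter("/tmp/example.gguf", "llama")
special_vocab = gguf.SpecialVocab(Path("/path/to/hf-model-dir"), load_merges = True)
special_vocab.add_to_gguf(writer)  # adds merges and any discovered special token ids as metadata
# ... a real converter would call the writer's header/KV/tensor write methods here,
# then close writer.fout.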