Diffstat (limited to 'gguf-py/gguf/gguf_writer.py')
-rw-r--r-- | gguf-py/gguf/gguf_writer.py | 445
1 file changed, 351 insertions, 94 deletions
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index a697f657..ba6f53cd 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -7,6 +7,8 @@ import struct
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, auto
+from math import prod
+from pathlib import Path
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
@@ -31,6 +33,9 @@ from .quants import quant_shape_from_byte_shape
 logger = logging.getLogger(__name__)
 
 
+SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
+
+
 @dataclass
 class TensorInfo:
     shape: Sequence[int]
@@ -55,11 +60,11 @@ class WriterState(Enum):
 
 
 class GGUFWriter:
-    fout: BufferedWriter | None
-    path: os.PathLike[str] | str | None
+    fout: list[BufferedWriter] | None
+    path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: dict[str, TensorInfo]
-    kv_data: dict[str, GGUFValue]
+    tensors: list[dict[str, TensorInfo]]
+    kv_data: list[dict[str, GGUFValue]]
     state: WriterState
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
@@ -76,29 +81,89 @@ class GGUFWriter:
     }
 
     def __init__(
-        self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
-        endianess: GGUFEndian = GGUFEndian.LITTLE,
+        self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
     ):
         self.fout = None
-        self.path = path
+        self.path = Path(path) if path else None
         self.arch = arch
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
         self.use_temp_file = use_temp_file
         self.temp_file = None
-        self.tensors = dict()
-        self.kv_data = dict()
+        self.tensors = [{}]
+        self.kv_data = [{}]
+        self.split_max_tensors = split_max_tensors
+        self.split_max_size = split_max_size
+        self.dry_run = dry_run
+        self.small_first_shard = small_first_shard
         logger.info("gguf: This GGUF file is for {0} Endian only".format(
             "Big" if self.endianess == GGUFEndian.BIG else "Little",
         ))
         self.state = WriterState.NO_FILE
 
+        if self.small_first_shard:
+            self.tensors.append({})
+
         self.add_architecture()
 
-    def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
+    def get_total_parameter_count(self) -> tuple[int, int, int, int]:
+        total_params = 0
+        shared_params = 0
+        expert_params = 0
+
+        expert_sum = 0
+        n_expert_tensors = 0
+
+        last_lora_a: tuple[str, TensorInfo] | None = None
+
+        for tensors in self.tensors:
+            for name, info in tensors.items():
+
+                shape = info.shape
+
+                if name.endswith(".lora_a"):
+                    last_lora_a = (name, info)
+                    continue
+                elif name.endswith(".lora_b"):
+                    if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
+                        # Bail when the LoRA pair can't be found trivially
+                        logger.warning("can't measure LoRA size correctly, tensor order is unusual")
+                        return 0, 0, 0, 0
+                    else:
+                        shape = (*shape[:-1], last_lora_a[1].shape[-1])
+
+                size = prod(shape)
+
+                if "_exps." in name:
+                    expert_params += (size // shape[-3])
+                    expert_sum += shape[-3]
+                    n_expert_tensors += 1
+                else:
+                    shared_params += size
+
+                total_params += size
+
+        # Hopefully this should work even for variable-expert-count models
+        expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
+
+        # Negate the total to signal it's likely not exact
+        if last_lora_a is not None:
+            total_params = -total_params
+
+        # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
+        return total_params, shared_params, expert_params, expert_count
+
+    def format_shard_names(self, path: Path) -> list[Path]:
+        if len(self.tensors) == 1:
+            return [path]
+        return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
+
+    def open_output_file(self, path: Path | None = None) -> None:
         if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
             # allow calling this multiple times as long as the path is the same
             return
+
         if self.state is not WriterState.NO_FILE:
             raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
@@ -106,22 +171,60 @@ class GGUFWriter:
         self.path = path
 
         if self.path is not None:
-            if self.fout is not None:
-                self.fout.close()
-            self.fout = open(self.path, "wb")
+            filenames = self.print_plan()
+            self.fout = [open(filename, "wb") for filename in filenames]
             self.state = WriterState.EMPTY
 
-    def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
+    def print_plan(self) -> list[Path]:
+        logger.info("Writing the following files:")
+        assert self.path is not None
+        filenames = self.format_shard_names(self.path)
+        assert len(filenames) == len(self.tensors)
+        for name, tensors in zip(filenames, self.tensors):
+            logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
+
+        if self.dry_run:
+            logger.info("Dry run, not writing files")
+            for name in filenames:
+                print(name)  # noqa: NP100
+            exit()
+
+        return filenames
+
+    def add_shard_kv_data(self) -> None:
+        if len(self.tensors) == 1:
+            return
+
+        total_tensors = sum(len(t) for t in self.tensors)
+        assert self.fout is not None
+        total_splits = len(self.fout)
+        self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
+        for i, kv_data in enumerate(self.kv_data):
+            kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
+            kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
+            kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
+
+    def write_header_to_file(self, path: Path | None = None) -> None:
+        if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
+            logger.warning("Model fails split requirements, not splitting")
+
         self.open_output_file(path)
 
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
-        self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
-        self._write_packed("I", GGUF_VERSION)
-        self._write_packed("Q", len(self.tensors))
-        self._write_packed("Q", len(self.kv_data))
-        self.flush()
+        assert self.fout is not None
+        assert len(self.fout) == len(self.tensors)
+        assert len(self.kv_data) == 1
+
+        self.add_shard_kv_data()
+
+        for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
+            fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
+            fout.write(self._pack("I", GGUF_VERSION))
+            fout.write(self._pack("Q", len(tensors)))
+            fout.write(self._pack("Q", len(kv_data)))
+            fout.flush()
         self.state = WriterState.HEADER
 
     def write_kv_data_to_file(self) -> None:
@@ -129,13 +232,15 @@ class GGUFWriter:
             raise ValueError(f'Expected output file to contain the header, got {self.state}')
         assert self.fout is not None
 
-        kv_data = bytearray()
+        for fout, kv_data in zip(self.fout, self.kv_data):
+            kv_bytes = bytearray()
+
+            for key, val in kv_data.items():
+                kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
+                kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
 
-        for key, val in self.kv_data.items():
-            kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
-            kv_data += self._pack_val(val.value, val.type, add_vtype=True)
+            fout.write(kv_bytes)
 
-        self.fout.write(kv_data)
         self.flush()
         self.state = WriterState.KV_DATA
 
@@ -144,28 +249,29 @@ class GGUFWriter:
             raise ValueError(f'Expected output file to contain KV data, got {self.state}')
         assert self.fout is not None
 
-        ti_data = bytearray()
-        offset_tensor = 0
-
-        for name, ti in self.tensors.items():
-            ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
-            n_dims = len(ti.shape)
-            ti_data += self._pack("I", n_dims)
-            for i in range(n_dims):
-                ti_data += self._pack("Q", ti.shape[n_dims - 1 - i])
-            ti_data += self._pack("I", ti.dtype)
-            ti_data += self._pack("Q", offset_tensor)
-            offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
-
-        self.fout.write(ti_data)
-        self.flush()
+        for fout, tensors in zip(self.fout, self.tensors):
+            ti_data = bytearray()
+            offset_tensor = 0
+
+            for name, ti in tensors.items():
+                ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
+                n_dims = len(ti.shape)
+                ti_data += self._pack("I", n_dims)
+                for j in range(n_dims):
+                    ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
+                ti_data += self._pack("I", ti.dtype)
+                ti_data += self._pack("Q", offset_tensor)
+                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
+
+            fout.write(ti_data)
+            fout.flush()
         self.state = WriterState.TI_DATA
 
     def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
-        if key in self.kv_data:
+        if any(key in kv_data for kv_data in self.kv_data):
             raise ValueError(f'Duplicated key name {key!r}')
 
-        self.kv_data[key] = GGUFValue(value=val, type=vtype)
+        self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
 
     def add_uint8(self, key: str, val: int) -> None:
         self.add_key_value(key,val, GGUFValueType.UINT8)
@@ -206,9 +312,6 @@ class GGUFWriter:
         self.add_key_value(key, val, GGUFValueType.STRING)
 
     def add_array(self, key: str, val: Sequence[Any]) -> None:
-        if not isinstance(val, Sequence):
-            raise ValueError("Value must be a sequence for array type")
-
         self.add_key_value(key, val, GGUFValueType.ARRAY)
 
     @staticmethod
@@ -222,7 +325,7 @@ class GGUFWriter:
         if self.state is not WriterState.NO_FILE:
             raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
 
-        if name in self.tensors:
+        if any(name in tensors for tensors in self.tensors):
            raise ValueError(f'Duplicated tensor name {name!r}')
 
         if raw_dtype is None:
@@ -247,7 +350,18 @@ class GGUFWriter:
             if tensor_dtype == np.uint8:
                 tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
 
-        self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
+        # make sure there is at least one tensor before splitting
+        if len(self.tensors[-1]) > 0:
+            if (  # split when over tensor limit
+                self.split_max_tensors != 0
+                and len(self.tensors[-1]) >= self.split_max_tensors
+            ) or (  # split when over size limit
+                self.split_max_size != 0
+                and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
+            ):
+                self.tensors.append({})
+
+        self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
 
     def add_tensor(
         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
@@ -264,7 +378,7 @@ class GGUFWriter:
         self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
 
         if self.temp_file is None:
-            self.tensors[name].tensor = tensor
+            self.tensors[-1][name].tensor = tensor
             return
 
         tensor.tofile(self.temp_file)
@@ -282,9 +396,24 @@ class GGUFWriter:
         if self.endianess == GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
 
-        self.write_padding(self.fout, self.fout.tell())
-        tensor.tofile(self.fout)
-        self.write_padding(self.fout, tensor.nbytes)
+
+        file_id = -1
+        for i, tensors in enumerate(self.tensors):
+            if len(tensors) > 0:
+                file_id = i
+                break
+
+        fout = self.fout[file_id]
+
+        # pop the first tensor info
+        # TODO: cleaner way to get the first key
+        first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
+        ti = self.tensors[file_id].pop(first_tensor_name)
+        assert ti.nbytes == tensor.nbytes
+
+        self.write_padding(fout, fout.tell())
+        tensor.tofile(fout)
+        self.write_padding(fout, tensor.nbytes)
 
         self.state = WriterState.WEIGHTS
 
@@ -293,31 +422,43 @@ class GGUFWriter:
 
         assert self.fout is not None
 
-        self.write_padding(self.fout, self.fout.tell())
+        for fout in self.fout:
+            self.write_padding(fout, fout.tell())
 
         if self.temp_file is None:
+            shard_bar = None
             bar = None
 
             if progress:
                 from tqdm import tqdm
 
-                total_bytes = sum(t.nbytes for t in self.tensors.values())
+                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
+                if len(self.fout) > 1:
+                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
 
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            # relying on the fact that Python dicts preserve insertion order (since 3.7)
-            for ti in self.tensors.values():
-                assert ti.tensor is not None # can only iterate once over the tensors
-                assert ti.tensor.nbytes == ti.nbytes
-                ti.tensor.tofile(self.fout)
-                if bar is not None:
-                    bar.update(ti.nbytes)
-                self.write_padding(self.fout, ti.nbytes)
-                ti.tensor = None
+            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
+                if shard_bar is not None:
+                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
+                    total = sum(ti.nbytes for ti in tensors.values())
+                    shard_bar.reset(total=(total if total > 0 else None))
+
+                # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                for ti in tensors.values():
+                    assert ti.tensor is not None # can only iterate once over the tensors
+                    assert ti.tensor.nbytes == ti.nbytes
+                    ti.tensor.tofile(fout)
+                    if shard_bar is not None:
+                        shard_bar.update(ti.nbytes)
+                    if bar is not None:
+                        bar.update(ti.nbytes)
+                    self.write_padding(fout, ti.nbytes)
+                    ti.tensor = None
         else:
             self.temp_file.seek(0)
 
-            shutil.copyfileobj(self.temp_file, self.fout)
+            shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
             self.flush()
             self.temp_file.close()
 
@@ -325,53 +466,129 @@ class GGUFWriter:
 
     def flush(self) -> None:
         assert self.fout is not None
-        self.fout.flush()
+        for fout in self.fout:
+            fout.flush()
 
     def close(self) -> None:
         if self.fout is not None:
-            self.fout.close()
+            for fout in self.fout:
+                fout.close()
             self.fout = None
 
+    def add_type(self, type_name: str) -> None:
+        self.add_string(Keys.General.TYPE, type_name)
+
     def add_architecture(self) -> None:
         self.add_string(Keys.General.ARCHITECTURE, self.arch)
 
+    def add_quantization_version(self, quantization_version: int) -> None:
+        self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
+
+    def add_custom_alignment(self, alignment: int) -> None:
+        self.data_alignment = alignment
+        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+
+    def add_file_type(self, ftype: int) -> None:
+        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+
+    def add_name(self, name: str) -> None:
+        self.add_string(Keys.General.NAME, name)
+
     def add_author(self, author: str) -> None:
         self.add_string(Keys.General.AUTHOR, author)
 
     def add_version(self, version: str) -> None:
         self.add_string(Keys.General.VERSION, version)
 
-    def add_tensor_data_layout(self, layout: str) -> None:
-        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+    def add_organization(self, organization: str) -> None:
+        self.add_string(Keys.General.ORGANIZATION, organization)
 
-    def add_url(self, url: str) -> None:
-        self.add_string(Keys.General.URL, url)
+    def add_finetune(self, finetune: str) -> None:
+        self.add_string(Keys.General.FINETUNE, finetune)
+
+    def add_basename(self, basename: str) -> None:
+        self.add_string(Keys.General.BASENAME, basename)
 
     def add_description(self, description: str) -> None:
         self.add_string(Keys.General.DESCRIPTION, description)
 
-    def add_licence(self, licence: str) -> None:
-        self.add_string(Keys.General.LICENSE, licence)
+    def add_quantized_by(self, quantized: str) -> None:
+        self.add_string(Keys.General.QUANTIZED_BY, quantized)
+
+    def add_size_label(self, size_label: str) -> None:
+        self.add_string(Keys.General.SIZE_LABEL, size_label)
+
+    def add_license(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE, license)
+
+    def add_license_name(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE_NAME, license)
+
+    def add_license_link(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE_LINK, license)
+
+    def add_url(self, url: str) -> None:
+        self.add_string(Keys.General.URL, url)
+
+    def add_doi(self, doi: str) -> None:
+        self.add_string(Keys.General.DOI, doi)
+
+    def add_uuid(self, uuid: str) -> None:
+        self.add_string(Keys.General.UUID, uuid)
+
+    def add_repo_url(self, repo_url: str) -> None:
+        self.add_string(Keys.General.REPO_URL, repo_url)
 
     def add_source_url(self, url: str) -> None:
         self.add_string(Keys.General.SOURCE_URL, url)
 
-    def add_source_hf_repo(self, repo: str) -> None:
-        self.add_string(Keys.General.SOURCE_HF_REPO, repo)
+    def add_source_doi(self, doi: str) -> None:
+        self.add_string(Keys.General.SOURCE_DOI, doi)
 
-    def add_file_type(self, ftype: int) -> None:
-        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+    def add_source_uuid(self, uuid: str) -> None:
+        self.add_string(Keys.General.SOURCE_UUID, uuid)
 
-    def add_name(self, name: str) -> None:
-        self.add_string(Keys.General.NAME, name)
+    def add_source_repo_url(self, repo_url: str) -> None:
+        self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
 
-    def add_quantization_version(self, quantization_version: int) -> None:
-        self.add_uint32(
-            Keys.General.QUANTIZATION_VERSION, quantization_version)
+    def add_base_model_count(self, source_count: int) -> None:
+        self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
 
-    def add_custom_alignment(self, alignment: int) -> None:
-        self.data_alignment = alignment
-        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+    def add_base_model_name(self, source_id: int, name: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
+
+    def add_base_model_author(self, source_id: int, author: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
+
+    def add_base_model_version(self, source_id: int, version: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
+
+    def add_base_model_organization(self, source_id: int, organization: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
+
+    def add_base_model_url(self, source_id: int, url: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
+
+    def add_base_model_doi(self, source_id: int, doi: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
+
+    def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
+
+    def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
+
+    def add_tags(self, tags: Sequence[str]) -> None:
+        self.add_array(Keys.General.TAGS, tags)
+
+    def add_languages(self, languages: Sequence[str]) -> None:
+        self.add_array(Keys.General.LANGUAGES, languages)
+
+    def add_datasets(self, datasets: Sequence[str]) -> None:
+        self.add_array(Keys.General.DATASETS, datasets)
+
+    def add_tensor_data_layout(self, layout: str) -> None:
+        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_vocab_size(self, size: int) -> None:
         self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
@@ -388,8 +605,11 @@ class GGUFWriter:
     def add_leading_dense_block_count(self, length: int) -> None:
         self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
 
-    def add_feed_forward_length(self, length: int) -> None:
-        self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+    def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
+        if isinstance(length, int):
+            self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+        else:
+            self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
     def add_expert_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
@@ -400,11 +620,20 @@ class GGUFWriter:
     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
-    def add_head_count(self, count: int) -> None:
-        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+    def add_decoder_start_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
 
-    def add_head_count_kv(self, count: int) -> None:
-        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+    def add_head_count(self, count: int | Sequence[int]) -> None:
+        if isinstance(count, int):
+            self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+        else:
+            self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+
+    def add_head_count_kv(self, count: int | Sequence[int]) -> None:
+        if isinstance(count, int):
+            self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+        else:
+            self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
 
     def add_key_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
@@ -421,6 +650,12 @@ class GGUFWriter:
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
 
+    def add_attn_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+    def add_final_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
 
@@ -448,6 +683,12 @@ class GGUFWriter:
     def add_kv_lora_rank(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
 
+    def add_relative_attn_buckets_count(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
+
+    def add_sliding_window(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
 
@@ -538,6 +779,12 @@ class GGUFWriter:
     def add_add_space_prefix(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
 
+    def add_remove_extra_whitespaces(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
+
+    def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
+        self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
+
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
         if not isinstance(value, str):
             template_default = None
@@ -599,9 +846,12 @@ class GGUFWriter:
             kv_data += self._pack("Q", len(encoded_val))
             kv_data += encoded_val
         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
-            ltype = GGUFValueType.get_type(val[0])
-            if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
-                raise ValueError("All items in a GGUF array should be of the same type")
+            if isinstance(val, bytes):
+                ltype = GGUFValueType.UINT8
+            else:
+                ltype = GGUFValueType.get_type(val[0])
+                if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
+                    raise ValueError("All items in a GGUF array should be of the same type")
             kv_data += self._pack("I", ltype)
             kv_data += self._pack("Q", len(val))
             for item in val:
@@ -611,6 +861,13 @@ class GGUFWriter:
 
         return kv_data
 
-    def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
-        assert self.fout is not None
-        self.fout.write(self._pack(fmt, value, skip_pack_prefix))
+    @staticmethod
+    def format_n_bytes_to_str(num: int) -> str:
+        if num == 0:
+            return "negligible - metadata only"
+        fnum = float(num)
+        for unit in ("", "K", "M", "G"):
+            if abs(fnum) < 1000.0:
+                return f"{fnum:3.1f}{unit}"
+            fnum /= 1000.0
+        return f"{fnum:.1f}T - over 1TB, split recommended"
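
Note: a minimal usage sketch of the splitting options added in this diff, using only the GGUFWriter methods visible above; the output path, architecture, block count, and dummy tensor names/data are purely illustrative, not from any real conversion script.

import numpy as np
from gguf import GGUFWriter

# Start a new shard once the current one holds 128 tensors; 0 disables the size-based limit.
writer = GGUFWriter("model.gguf", arch="llama",
                    split_max_tensors=128, split_max_size=0,
                    dry_run=False,            # True would only print the shard plan and exit
                    small_first_shard=True)   # keep shard 1 metadata-only

writer.add_block_count(32)
for i in range(256):
    # hypothetical tensor names; add_tensor_info()/add_tensor() decide when to open a new shard
    writer.add_tensor(f"blk.{i % 32}.dummy.weight", np.zeros((64, 64), dtype=np.float32))

# The usual three write phases; with more than one shard each phase loops over every output file,
# and add_shard_kv_data() stamps each shard with the Keys.Split split-number/split-count/tensor-count entries.
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file(progress=True)
writer.close()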
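Shard file names come from SHARD_NAME_FORMAT via format_shard_names(); a small sketch of what that produces (the base file name here is hypothetical):

from pathlib import Path

SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"

path = Path("Mistral-7B-v0.1-F16.gguf")
n_shards = 3
shard_names = [
    path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, n_shards))
    for i in range(n_shards)
]
# -> Mistral-7B-v0.1-F16-00001-of-00003.gguf ... Mistral-7B-v0.1-F16-00003-of-00003.gguf

print_plan() reports each shard's size through format_n_bytes_to_str(), which divides by 1000 per unit, so 1_500_000_000 bytes renders as "1.5G" and 0 as "negligible - metadata only".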
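get_total_parameter_count() treats any tensor whose name contains "_exps." as a stacked-expert tensor: shape[-3] is taken as the expert count, only one expert's worth of weights goes into expert_params, and the full size still goes into total_params. A worked sketch with hypothetical shapes:

from math import prod

shape_attn = (4096, 4096)        # ordinary (shared) tensor
shape_exps = (8, 14336, 4096)    # stacked experts, e.g. a "blk.0.ffn_up_exps.weight"-style tensor

shared_params = prod(shape_attn)                     # 16_777_216
expert_params = prod(shape_exps) // shape_exps[-3]   # 58_720_256 (a single expert's worth)
expert_count  = shape_exps[-3]                       # 8
total_params  = prod(shape_attn) + prod(shape_exps)  # 486_539_264 (all experts counted)

When .lora_a/.lora_b pairs are present, the method instead reconstructs the shape of the underlying matrix from the pair and negates total_params to signal that the figure is likely not exact.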