summary | refs | log | tree | commit | diff
path: root/gguf-py/gguf/gguf_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py/gguf/gguf_reader.py')
-rw-r--r-- gguf-py/gguf/gguf_reader.py | 35
1 files changed, 28 insertions, 7 deletions
diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py
index e48bc00c..e8e61abf 100644
--- a/gguf-py/gguf/gguf_reader.py
+++ b/gguf-py/gguf/gguf_reader.py
@@ -67,8 +67,9 @@ class ReaderTensor(NamedTuple):
class GGUFReader:
# I - same as host, S - swapped
- byte_order: Literal['I'] | Literal['S'] = 'I'
+ byte_order: Literal['I', 'S'] = 'I'
alignment: int = GGUF_DEFAULT_ALIGNMENT
+ data_offset: int
# Note: Internal helper, API may change.
gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
@@ -85,12 +86,16 @@ class GGUFReader:
GGUFValueType.BOOL: np.bool_,
}
- def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
+ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
self.data = np.memmap(path, mode = mode)
offs = 0
+
+ # Check for GGUF magic
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
raise ValueError('GGUF magic invalid')
offs += 4
+
+ # Check GGUF version
temp_version = self._get(offs, np.uint32)
if temp_version[0] & 65535 == 0:
# If we get 0 here that means it's (probably) a GGUF file created for
@@ -103,12 +108,16 @@ class GGUFReader:
self.fields: OrderedDict[str, ReaderField] = OrderedDict()
self.tensors: list[ReaderTensor] = []
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
+
+ # Check tensor count and kv count
temp_counts = self._get(offs, np.uint64, 2)
offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
tensor_count, kv_count = temp_counts
offs = self._build_fields(offs, kv_count)
- offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
+
+ # Build Tensor Info Fields
+ offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
new_align = self.fields.get('general.alignment')
if new_align is not None:
if new_align.types != [GGUFValueType.UINT32]:
@@ -117,6 +126,7 @@ class GGUFReader:
padding = offs % self.alignment
if padding != 0:
offs += self.alignment - padding
+ self.data_offset = offs
self._build_tensors(offs, tensors_fields)
_DT = TypeVar('_DT', bound = npt.DTypeLike)
@@ -130,7 +140,7 @@ class GGUFReader:
return self.tensors[idx]
def _get(
- self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
+ self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize)
@@ -193,18 +203,29 @@ class GGUFReader:
# We can't deal with this one.
raise ValueError('Unknown/unhandled field type {gtype}')
- def _get_tensor(self, orig_offs: int) -> ReaderField:
+ def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
offs = orig_offs
+
+ # Get Tensor Name
name_len, name_data = self._get_str(offs)
offs += int(name_len.nbytes + name_data.nbytes)
+
+ # Get Tensor Dimensions Count
n_dims = self._get(offs, np.uint32)
offs += int(n_dims.nbytes)
+
+ # Get Tensor Dimension Array
dims = self._get(offs, np.uint64, n_dims[0])
offs += int(dims.nbytes)
+
+ # Get Tensor Encoding Scheme Type
raw_dtype = self._get(offs, np.uint32)
offs += int(raw_dtype.nbytes)
+
+ # Get Tensor Offset
offset_tensor = self._get(offs, np.uint64)
offs += int(offset_tensor.nbytes)
+
return ReaderField(
orig_offs,
str(bytes(name_data), encoding = 'utf-8'),
@@ -233,10 +254,10 @@ class GGUFReader:
offs += field_size
return offs
- def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
+ def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
tensor_fields = []
for _ in range(count):
- field = self._get_tensor(offs)
+ field = self._get_tensor_info_field(offs)
offs += sum(int(part.nbytes) for part in field.parts)
tensor_fields.append(field)
return offs, tensor_fields