summaryrefslogtreecommitdiff
path: root/gguf-py/gguf/gguf_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py/gguf/gguf_writer.py')
-rw-r--r--gguf-py/gguf/gguf_writer.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index e49c5db6..9c1eeac3 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -196,9 +196,6 @@ class GGUFWriter:
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}')
- if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
- raise ValueError("Only F32 and F16 tensors are supported for now")
-
encoded_name = name.encode("utf8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name
@@ -207,7 +204,18 @@ class GGUFWriter:
for i in range(n_dims):
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
if raw_dtype is None:
- dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
+ if tensor_shape == np.float32:
+ dtype = GGMLQuantizationType.F32
+ elif tensor_dtype == np.float16:
+ dtype = GGMLQuantizationType.F16
+ elif tensor_dtype == np.int8:
+ dtype = GGMLQuantizationType.I8
+ elif tensor_dtype == np.int16:
+ dtype = GGMLQuantizationType.I16
+ elif tensor_dtype == np.int32:
+ dtype = GGMLQuantizationType.I32
+ else:
+ raise ValueError("Only F32, F16, I8, I16, I32 tensors are supported for now")
else:
dtype = raw_dtype
self.ti_data += self._pack("I", dtype)