author     kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>  2024-03-01 06:08:08 -0800
committer  GitHub <noreply@github.com>  2024-03-01 16:08:08 +0200
commit     e7433867288d2f142cffe596f3751bda5d7ee2c7 (patch)
tree       143c0e6537711cc23994a0b59fc17c19db2db678 /convert-hf-to-gguf.py
parent     f49a5356865ced0eca1df9f9d84631dfef71b9dc (diff)
gemma : fix bfloat16 -> float16 conversion issue (#5810)
Diffstat (limited to 'convert-hf-to-gguf.py')
-rwxr-xr-x  convert-hf-to-gguf.py  7
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index ae30b2a7..d3e8ec1f 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1811,16 +1811,15 @@ class GemmaModel(Model):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
 
         for name, data_torch in self.get_tensors():
-            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
-            if name.endswith("norm.weight"):
-                data_torch = data_torch + 1
-
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
             if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
             data = data_torch.squeeze().numpy()
 
             # map tensor names
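
Why the reorder matters: as the commit title indicates, the Gemma checkpoints arrive in bfloat16, which carries only an 8-bit significand. With the old order, `data_torch + 1` was computed while the tensor was still bfloat16, so the small stored norm weight was rounded onto bfloat16's coarse grid around 1.0 before any later cast; with the new order, the lossless bfloat16 -> float32 widening happens first and the `+ 1` runs in float32. A minimal sketch of the difference, using a made-up weight value of 0.0123 purely for illustration (not taken from the model):

import torch

# Illustration only, not part of the patch: compare the two orderings.
w = torch.tensor([0.0123], dtype=torch.bfloat16)   # hypothetical norm.weight entry

old_order = (w + 1).to(torch.float32)              # add in bfloat16, then widen
new_order = w.to(torch.float32) + 1                # widen first (lossless), then add

print(old_order.item(), new_order.item())
# Near 1.0, bfloat16 values are spaced 2**-7 = 0.0078125 apart, so the old
# order snaps 1.0123 onto that grid, while the new order keeps the stored
# weight's full precision.

The same reasoning applies when the converter later writes an f16 GGUF: float16 near 1.0 is about 8x finer than bfloat16, so starting the downcast from the float32 sum preserves more of the weight than downcasting a value already rounded in bfloat16.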