Diffstat (limited to 'gguf-py')
-rw-r--r--  gguf-py/gguf/gguf.py  29
1 file changed, 23 insertions(+), 6 deletions(-)
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index 6b7d6542..727b4e55 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -7,7 +7,7 @@ import shutil
 import struct
 import sys
 import tempfile
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from io import BufferedWriter
 from pathlib import Path
 from typing import IO, Any, BinaryIO, Callable, Sequence
@@ -53,9 +53,12 @@ KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
 KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
 
 # RoPE
-KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
-KEY_ROPE_FREQ_BASE       = "{arch}.rope.freq_base"
-KEY_ROPE_SCALE_LINEAR    = "{arch}.rope.scale_linear"
+KEY_ROPE_DIMENSION_COUNT      = "{arch}.rope.dimension_count"
+KEY_ROPE_FREQ_BASE            = "{arch}.rope.freq_base"
+KEY_ROPE_SCALING_TYPE         = "{arch}.rope.scaling.type"
+KEY_ROPE_SCALING_FACTOR       = "{arch}.rope.scaling.factor"
+KEY_ROPE_SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
+KEY_ROPE_SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
 
 # tokenization
 KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
@@ -577,6 +580,11 @@ class TokenType(IntEnum):
     UNUSED = 5
     BYTE = 6
 
+class RopeScalingType(Enum):
+    NONE = 'none'
+    LINEAR = 'linear'
+    YARN = 'yarn'
+
 #
 # implementation
 #
@@ -948,8 +956,17 @@ class GGUFWriter:
     def add_rope_freq_base(self, value: float):
         self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
 
-    def add_rope_scale_linear(self, value: float):
-        self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
+    def add_rope_scaling_type(self, value: RopeScalingType):
+        self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
+
+    def add_rope_scaling_factor(self, value: float):
+        self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_orig_ctx_len(self, value: int):
+        self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+
+    def add_rope_scaling_finetuned(self, value: bool):
+        self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
 
     def add_tokenizer_model(self, model: str):
         self.add_string(KEY_TOKENIZER_MODEL, model)
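
For reference, a minimal usage sketch of the new writer API (not part of the diff). It assumes GGUFWriter is constructed with an output path and architecture name as elsewhere in gguf.py; the output file "model.gguf", the "llama" architecture, and the scaling values (4x linear scaling of a 4096-token base context) are hypothetical examples.

    import gguf

    writer = gguf.GGUFWriter("model.gguf", "llama")  # hypothetical output path / arch

    # replaces the removed add_rope_scale_linear(factor) call
    writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
    writer.add_rope_scaling_factor(4.0)          # scale factor applied to RoPE
    writer.add_rope_scaling_orig_ctx_len(4096)   # context length before scaling
    writer.add_rope_scaling_finetuned(False)     # not fine-tuned for the extended context

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.close()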