author     Douglas Hanley <thesecretaryofwar@gmail.com>  2024-02-15 11:21:49 -0600
committer  GitHub <noreply@github.com>                   2024-02-15 12:21:49 -0500
commit     4524290e87b8e107cc2b56e1251751546f4b9051 (patch)
tree       38d50aa2850bc2ecb53619fb9f03e0f91953a4c6 /gguf-py
parent     c06e45d72983d9ace7b1535f7e7ea258d212169e (diff)
Use correct type of pooling for embedding models (#5500)
Diffstat (limited to 'gguf-py')
-rw-r--r--  gguf-py/gguf/constants.py    8
-rw-r--r--  gguf-py/gguf/gguf_writer.py  5
2 files changed, 10 insertions, 3 deletions
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 9986ce9d..114a9a97 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -40,7 +40,7 @@ class Keys:
         TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
         EXPERT_COUNT       = "{arch}.expert_count"
         EXPERT_USED_COUNT  = "{arch}.expert_used_count"
-        POOLING_LAYER      = "{arch}.pooling_layer"
+        POOLING_TYPE       = "{arch}.pooling_type"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -561,6 +561,12 @@ class RopeScalingType(Enum):
     YARN = 'yarn'
 
 
+class PoolingType(IntEnum):
+    NONE = 0
+    MEAN = 1
+    CLS  = 2
+
+
 class GGMLQuantizationType(IntEnum):
     F32 = 0
     F16 = 1
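
The new PoolingType enum mirrors the ways embedding models reduce per-token hidden states to a single vector. As a rough illustration only, a conversion script could map a model's configured pooling-mode string onto the enum; the helper below is hypothetical and not part of this commit.

# Hypothetical helper, not part of this commit: map a pooling-mode string
# from a model's configuration onto the new PoolingType enum.
from gguf.constants import PoolingType

def pooling_type_from_str(name: str) -> PoolingType:
    mapping = {
        "none": PoolingType.NONE,  # keep per-token embeddings, no pooling
        "mean": PoolingType.MEAN,  # average the token embeddings
        "cls":  PoolingType.CLS,   # use the [CLS] token embedding
    }
    return mapping[name.lower()]
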
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 26724bf9..e4681475 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -19,6 +19,7 @@ from .constants import (
     GGUFValueType,
     Keys,
     RopeScalingType,
+    PoolingType,
     TokenType,
 )
@@ -360,8 +361,8 @@ class GGUFWriter:
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
 
-    def add_pooling_layer(self, value: bool) -> None:
-        self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value)
+    def add_pooling_type(self, value: PoolingType) -> None:
+        self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value)
 
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
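
For reference, a minimal usage sketch of the renamed writer method; the output path and architecture below are placeholders, not taken from this commit.

# Hypothetical usage, not part of this commit: record mean pooling for a
# BERT-style embedding model. This stores the metadata key
# "bert.pooling_type" as a uint32 with the value PoolingType.MEAN (1).
from gguf import GGUFWriter
from gguf.constants import PoolingType

writer = GGUFWriter("model.gguf", arch="bert")  # placeholder path and arch
writer.add_pooling_type(PoolingType.MEAN)
# ... remaining metadata and tensors would be added before writing the file.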