summaryrefslogtreecommitdiff
path: root/ggml-quants.h
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-01-11 20:39:39 +0100
committerGitHub <noreply@github.com>2024-01-11 21:39:39 +0200
commit49662cbed3e95f5976c070b85b9fd53fd577038d (patch)
treeb70cd0956715bc11696f6e47d26788e24c5112c4 /ggml-quants.h
parent3ba5b8ca8e6181a5c712c5b77595a29f1d3e2b97 (diff)
ggml : SOTA 2-bit quants (add IQ2_XS) (#4856)
* iq2_xs: basics * iq2_xs: this should have been in the basics * iq2_xs: CUDA and scalar CPU works * iq2_xs: WIP Metal * iq2_xs: Metal now works * iq2_xs: working, but dog slow, ARM_NEON dot product * iq2_xs: better ARM_NEON dot product We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when running on the CPU. * iq2_xs: AVX2 dot product - 19.5 t/s * iq2_xs: faster AVX2 dit product 21.4 t/s for TG-128, 59.2 t/s for PP-512. The latter is 2x compared to the previous version. * iq2_xs: had forgotten to delete iq2-data.h * Add llama enum for IQ2_XS --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml-quants.h')
-rw-r--r--ggml-quants.h12
1 files changed, 12 insertions, 0 deletions
diff --git a/ggml-quants.h b/ggml-quants.h
index 8dd911d4..df5e7ae8 100644
--- a/ggml-quants.h
+++ b/ggml-quants.h
@@ -174,6 +174,14 @@ typedef struct {
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+// 2.3125 bpw quants
+typedef struct {
+ ggml_fp16_t d;
+ uint16_t qs[QK_K/8];
+ uint8_t scales[QK_K/32];
+} block_iq2_xs;
+static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+
// Quantization
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
@@ -189,6 +197,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
+void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs * restrict y, int k);
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@@ -204,6 +213,7 @@ void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
+void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k);
// Dequantization
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@@ -220,6 +230,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
+void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k);
// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -234,3 +245,4 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx,
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);