summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-18 20:08:28 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:52 +0300
commit927e251a12fa287e13c6bd9667ee97d783486c09 (patch)
tree90ed8827fc28630f52e92d8b8ea664198a6f5829 /iqk_mul_mat.cpp
parent181fd9c56eaa64d0a92f9e8be7387f409cfa8745 (diff)
Bitnet(1.75 bpw): higher precision fp8 scale
Use 3 bits for the exponent and 5 bits for the mantissa. This makes PPL to be the same as fp16 (but the previous version with 4 bits for the exponent and mantissa was good enough for any practical purposes).
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp17
1 files changed, 5 insertions, 12 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index f4294d31..08f954e1 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -31,6 +31,7 @@
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "iqk_mul_mat.h"
+#include "iqk-quantize.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -1344,15 +1345,11 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
//auto step = bx / sizeof(block_iq1_bn);
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
- typedef union { float f; uint32_t i; } scale_t;
-
- scale_t scale;
for (int ix = 0; ix < nrc_x; ++ix) {
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
- uint16_t u = x[0].extra & 0xff;
- scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
@@ -1401,7 +1398,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, scale.f * hsum_float_8(accd[iy]));
+ info.store(ix, iy, d1 * hsum_float_8(accd[iy]));
}
}
@@ -4128,15 +4125,11 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
- typedef union { float f; uint32_t i; } scale_t;
-
- scale_t scale;
for (int ix = 0; ix < nrc_x; ++ix) {
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
- uint16_t u = x[0].extra & 0xff;
- scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+ float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);
@@ -4186,7 +4179,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, scale.f * vaddvq_f32(accd[iy]));
+ info.store(ix, iy, d1 * vaddvq_f32(accd[iy]));
}
}