summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_quantize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp44
1 files changed, 35 insertions, 9 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 0ea608bc..66cf32bc 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -2469,8 +2469,8 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
return nrows * nblock * sizeof(block_iq6_k);
}
-
-void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
+template <bool is_K32>
+void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_K * y = (block_q8_K *)vy;
@@ -2505,8 +2505,14 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
__m256i i1 = _mm256_cvtps_epi32(v1);
__m256i i2 = _mm256_cvtps_epi32(v2);
__m256i i3 = _mm256_cvtps_epi32(v3);
- y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
- y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
+ if constexpr (is_K32) {
+ int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+ auto bs = (float *)y[i].bsums;
+ bs[ib] = d*bsum;
+ } else {
+ y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
+ y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
+ }
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
i0 = _mm256_packs_epi16( i0, i2 );
@@ -2539,12 +2545,24 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
int v = nearest_int(iscale*x[j]);
y[i].qs[j] = MIN(127, v);
}
- for (int j = 0; j < QK_K/16; ++j) {
- int sum = 0;
- for (int ii = 0; ii < 16; ++ii) {
- sum += y[i].qs[j*16 + ii];
+ if constexpr (is_K32) {
+ auto bs = (float *)y[i].bsums;
+ float d = 1/iscale;
+ for (int j = 0; j < QK_K/32; ++j) {
+ int sum = 0;
+ for (int ii = 0; ii < 32; ++ii) {
+ sum += y[i].qs[j*32 + ii];
+ }
+ bs[j] = d*sum;
+ }
+ } else {
+ for (int j = 0; j < QK_K/16; ++j) {
+ int sum = 0;
+ for (int ii = 0; ii < 16; ++ii) {
+ sum += y[i].qs[j*16 + ii];
+ }
+ y[i].bsums[j] = sum;
}
- y[i].bsums[j] = sum;
}
y[i].d = 1/iscale;
x += QK_K;
@@ -2553,6 +2571,14 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
}
+void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
+ iqk_quantize_row_q8_K_T<false>(x, vy, k);
+}
+
+void quantize_row_q8_K32(const float * x, void * vy, int64_t k) {
+ iqk_quantize_row_q8_K_T<true>(x, vy, k);
+}
+
namespace {
static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size,
int n_per_row, const float * x, char * cy,