author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>    2024-09-09 14:56:34 +0300
committer  GitHub <noreply@github.com>                               2024-09-09 14:56:34 +0300
commit     8c86231f9306c81dc291c4c4a16f88bbc7c97793 (patch)
tree       d49325de2775076e1f71ddf94667d0cd02db3cc5
parent     bf4b19b474b78a6ddfa1f0fe19f76f3c7ac92030 (diff)
Adding IQ1_TN - 1.6875 bpw for TriLM ternary models (#44)
* Adding iq1_tn - 1.6875 bpw for TriLM ternary models
* iq1_tn: NEON
* iq1_tn: faster NEON
* iq2_bn: improve performance on NEON
  We now get TG-128 = 100 t/s for Bitnet-3B-1.58b!
* iq1_tn: improve AVX2
  PP-512 goes to 533 t/s up from 455. TG-128 @ 2 threads goes to 16.6 t/s up from 14.2.
  However, we seem to have a bottleneck somewhere as TG saturates at 8 threads.
* iq1_tn: improve Zen4
  PP-512 goes to 485 t/s up from 352. With FA we get 545 t/s up from 380.
  TG-128 @ 1 thread goes to 12.4 t/s up from 10.4.
  However, we seem to have a bottleneck somewhere as TG saturates at 8 threads.
* iq2_bn: improve on Zen4
  We now get PP-512 = 614 t/s up from 542 t/s
* iq2_bn: improve AVX2 implementation
  We now get PP-512 = 753 t/s up from 680 t/s.
* Remove unnecessary barrier in ggml_compute_forward_mul_mat

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
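The bits-per-weight figures quoted above follow directly from the block sizes this commit adds to ggml-common.h (block_iq1_tn is 54 bytes per QK_K = 256 weights; block_iq2_tn is a ggml_half plus QK_K/4 bytes). A minimal standalone check of the arithmetic, assuming only those sizes:

#include <cstdio>

// Block sizes taken from the static_asserts in ggml-common.h in this diff:
//   block_iq1_bn = 13 bytes per QK_IQ1BN = 64 weights
//   block_iq1_tn = 54 bytes per QK_K     = 256 weights
//   block_iq2_tn = sizeof(ggml_half) + QK_K/4 bytes per QK_K weights
int main() {
    const double QK_K = 256.0, QK_IQ1BN = 64.0;
    printf("iq1_bn: %.4f bpw\n", 8.0 * 13.0 / QK_IQ1BN);          // 1.6250
    printf("iq1_tn: %.4f bpw\n", 8.0 * 54.0 / QK_K);              // 1.6875
    printf("iq2_tn: %.4f bpw\n", 8.0 * (2.0 + QK_K / 4) / QK_K);  // 2.0625
    return 0;
}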
-rw-r--r--  examples/quantize/quantize.cpp     1
-rw-r--r--  ggml/include/ggml.h               38
-rw-r--r--  ggml/src/ggml-common.h             4
-rw-r--r--  ggml/src/ggml-quants.c             1
-rw-r--r--  ggml/src/ggml.c                   35
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp     265
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp     74
-rw-r--r--  ggml/src/iqk/iqk_quantize.h        6
-rw-r--r--  include/llama.h                   18
-rw-r--r--  src/llama.cpp                     10
10 files changed, 304 insertions, 148 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 9a08d625..c6153e45 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -28,6 +28,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", },
{ "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
+ { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.69 bpw quantization (TriLM)", },
{ "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index ab6d172d..5b46a70d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -391,15 +391,17 @@ extern "C" {
GGML_TYPE_Q4_0_4_4 = 31,
GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33,
- GGML_TYPE_IQ1_BN = 34,
- GGML_TYPE_IQ2_BN = 35,
- GGML_TYPE_Q8_K64 = 36,
- GGML_TYPE_IQ2_K = 37,
- GGML_TYPE_IQ3_K = 38,
- GGML_TYPE_IQ4_K = 39,
- GGML_TYPE_IQ5_K = 40,
- GGML_TYPE_IQ6_K = 41,
- GGML_TYPE_IQ2_TN = 42,
+ //
+ GGML_TYPE_IQ1_BN = 134,
+ GGML_TYPE_IQ2_BN = 135,
+ GGML_TYPE_Q8_K64 = 136,
+ GGML_TYPE_IQ2_K = 137,
+ GGML_TYPE_IQ3_K = 138,
+ GGML_TYPE_IQ4_K = 139,
+ GGML_TYPE_IQ5_K = 140,
+ GGML_TYPE_IQ6_K = 141,
+ GGML_TYPE_IQ2_TN = 142,
+ GGML_TYPE_IQ1_TN = 143,
GGML_TYPE_COUNT,
};
@@ -444,14 +446,16 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ2_K = 30, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ6_K = 34, // except 1d tensors
- GGML_FTYPE_MOSTLY_IQ2_TN = 35, // except 1d tensors
+ //
+ GGML_FTYPE_MOSTLY_IQ1_BN = 128, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_BN = 129, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_K = 130, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ3_K = 131, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ4_K = 132, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ5_K = 133, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ6_K = 134, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_TN = 135, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ1_TN = 136, // except 1d tensors
};
// available tensor operations:
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 57fdeb82..f1a34f7a 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -436,6 +436,10 @@ static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/paddi
// TriLM - implemented as 2.0625 bpw
//
typedef struct {
+ uint8_t qs[54];
+} block_iq1_tn;
+static_assert(sizeof(block_iq1_tn) == 54, "wrong iq1_tn block size/padding");
+typedef struct {
ggml_half d;
uint8_t qs[QK_K/4];
} block_iq2_tn;
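The new block_iq1_tn is an opaque 54-byte container; the row layout actually written by quantize_iq1_tn later in this diff is a single fp16 row scale followed by iq1_bn-style 13-byte blocks. A sketch of how a row is addressed, mirroring dequantize_row_iq1_tn and vec_dot_iq1_tn_q8_k below and assuming ggml-common.h is included (iq1_tn_row_view and view_iq1_tn_row are names chosen here for illustration, not part of the commit):

// Illustrative view of one IQ1_TN row: ggml_half scale, then packed iq1_bn blocks.
struct iq1_tn_row_view {
    float                scale;   // GGML_FP16_TO_FP32 of the stored per-row scale
    const block_iq1_bn * blocks;  // ternary digits, 13 bytes per 64 weights
};

static inline iq1_tn_row_view view_iq1_tn_row(const void * row) {
    iq1_tn_row_view v;
    v.scale  = GGML_FP16_TO_FP32(*(const ggml_half *)row);
    v.blocks = (const block_iq1_bn *)((const char *)row + sizeof(ggml_half));
    return v;
}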
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 981fb54b..a9a25761 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15015,6 +15015,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ5_K: break;
case GGML_TYPE_IQ6_K: break;
case GGML_TYPE_IQ2_TN: break;
+ case GGML_TYPE_IQ1_TN: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d562002e..4fdf9c18 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -985,6 +985,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
+ [GGML_TYPE_IQ1_TN] = {
+ .type_name = "iq1_tn",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq1_tn),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq1_tn,
+ .from_float = quantize_row_iq1_tn,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq1_tn_ref,
+ .vec_dot = vec_dot_iq1_tn_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K64,
+ .nrows = 1,
+ },
[GGML_TYPE_IQ4_NL] = {
.type_name = "iq4_nl",
.blck_size = QK4_NL,
@@ -3705,6 +3717,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break;
case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
case GGML_FTYPE_MOSTLY_IQ2_TN: wtype = GGML_TYPE_IQ2_TN; break;
+ case GGML_FTYPE_MOSTLY_IQ1_TN: wtype = GGML_TYPE_IQ1_TN; break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
@@ -10133,6 +10146,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -10519,6 +10533,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -10655,6 +10670,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -13078,14 +13094,14 @@ UseGgmlGemm1:;
int64_t t2 = ggml_time_us();
if (ith == 0) printf("quantize(%s): %d us\n", dst->name, (int)(t2 - t1));
#endif
- }
- if (ith == 0) {
- // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
- atomic_store(&params->shared->current_chunk, nth);
- }
+ if (ith == 0) {
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+ atomic_store(&params->shared->current_chunk, nth);
+ }
- ggml_barrier(params->shared);
+ ggml_barrier(params->shared);
+ }
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
@@ -13104,8 +13120,6 @@ UseGgmlGemm1:;
IQK_MulMat_Not_Available2:;
#endif
- ggml_barrier(params->shared);
-
#if GGML_USE_LLAMAFILE
if (src1->type != vec_dot_type) {
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -13692,6 +13706,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -14068,6 +14083,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -14338,6 +14354,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -14935,6 +14952,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -21722,6 +21740,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_TN: result = quantize_iq2_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ1_TN: result = quantize_iq1_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 55366ab1..b5e3cba3 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2078,15 +2078,16 @@ template <int nrc> struct Q8_K64 {
Q8_K64(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) {
const float * dptr = (const float *)info.src1_row(iy);
- std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
- y[iy] = (const int8_t *)(dptr + 4);
+ std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
+ y[iy] = (const int8_t *)(dptr + 8);
}
}
inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
- inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); }
+ inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
+ inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }
- float d[4*nrc_y];
+ float d[8*nrc_y];
const int8_t * y[nrc_y];
};
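The four extra floats per row carry a precomputed correction term: the dequantizer changes that follow stop subtracting 1 from the ternary digits, so the weight operand stays in {0,1,2} and can feed _mm256_dpbusd_epi32/_mm256_maddubs_epi16 directly as the unsigned argument. The correction uses the identity sum_j (u_j - 1)*q_j = sum_j u_j*q_j - sum_j q_j; quantize_row_q8_K64_ref (in iqk_quantize.cpp below) stores the matching d*sum(q) for each scale group right after the four scales, and the final _mm_fmsub_ps subtracts it. A scalar sketch of the same bookkeeping, with names invented here for illustration:

#include <cstdint>

// u[j] in {0,1,2} are the stored ternary digits (true weight is u[j]-1),
// q[j] are int8 activations, d is the activation scale for this group.
static float dot_ternary_with_offset(const uint8_t * u, const int8_t * q, int n, float d) {
    int32_t sum_uq = 0, sum_q = 0;
    for (int j = 0; j < n; ++j) {
        sum_uq += u[j] * q[j];   // what dpbusd/maddubs accumulate
        sum_q  += q[j];          // known at activation-quantization time
    }
    // (u - 1) . q == u . q - sum(q); the kernels compute d*sum_uq - (d*sum_q)
    return d * (float)sum_uq - d * (float)sum_q;
}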
@@ -2121,17 +2122,17 @@ struct DequantizerIQ1BN {
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
#ifdef HAVE_FANCY_SIMD
- v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
- v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
+ v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
+ v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
#else
- v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
- v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
+ v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
+ v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
#endif
}
};
-template <int nrc_y>
+template <int nrc_y, bool is_iq1_tn>
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
@@ -2143,11 +2144,18 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
const auto m1_16 = _mm256_set1_epi16(1);
#endif
- const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
+ const block_iq1_bn * x;
+ const char * cx0 = (const char *)vx;
+ float scale;
for (int ix = 0; ix < nrc_x; ++ix) {
- x = (const block_iq1_bn *)((const char *)vx + ix*bx);
+ const char * cx = cx0 + ix*bx;
+ if constexpr (is_iq1_tn) {
+ scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx);
+ cx += sizeof(ggml_half);
+ }
+ x = (const block_iq1_bn *)cx;
if constexpr (nrc_y == 1) {
__m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
@@ -2155,17 +2163,13 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
- auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]);
- auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]);
- acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
- acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
+ acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
+ acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
#else
- auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
- auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
#endif
@@ -2183,17 +2187,16 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
for (int iy = 0; iy < nrc_y; ++iy) {
#if defined __AVX512VNNI__ && defined __AVX512VL__
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
- auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
- auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
- accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+ val[0], q8.load_quants(iy, i, 0)),
+ val[1], q8.load_quants(iy, i, 1)),
+ val[2], q8.load_quants(iy, i, 2)),
+ val[3], q8.load_quants(iy, i, 3));
#else
- auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
- auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
@@ -2204,13 +2207,12 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
if (i < nb) {
deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+ val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
#else
- auto dot = _mm256_madd_epi16(m1_16,
- _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
}
@@ -2219,8 +2221,12 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
for (int iy = 0; iy < nrc_y; ++iy) {
auto vd = q8.scale(iy);
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
- auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
- info.store(ix, iy, hsum_float_4(sumf));
+ auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
+ if constexpr (is_iq1_tn) {
+ info.store(ix, iy, scale*hsum_float_4(sumf));
+ } else {
+ info.store(ix, iy, hsum_float_4(sumf));
+ }
}
}
@@ -2236,8 +2242,8 @@ struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
}
IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
- val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
- val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
+ val[0] = _mm256_and_si256(q2_1, mask2);
+ val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
}
IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
@@ -2270,15 +2276,15 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
for (int i = 0; i < nb/2; ++i) {
deq.prepare4(i, val);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
- deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]));
- acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
- deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]));
+ acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
+ val[1], q8.load_quants(0, i, 1));
+ acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
+ val[3], q8.load_quants(0, i, 3));
#else
- auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
- auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
#endif
@@ -2292,18 +2298,17 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
for (int i = 0; i < nb/2; ++i) {
deq.prepare4(i, val);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
- auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
- auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
- accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+ val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
+ val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
#else
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
- _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
- _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
- accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot;
+ _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
+ _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
+ _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
+ _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
}
}
@@ -2312,13 +2317,13 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
if (i < nb) {
deq.prepare2(i, val);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
+ val[1], q8.load_quants(iy, i/2, 1));
#else
- dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
- accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
+                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
}
}
@@ -2326,7 +2331,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
for (int iy = 0; iy < nrc_y; ++iy) {
auto vd = q8.scale(iy);
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
- auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+ auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
info.store(ix, iy, hsum_float_4(sumf));
}
}
@@ -3733,14 +3738,26 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_IQ1_BN:
assert (ne00 % QK_IQ1BN == 0);
- mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
- mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
- mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
- mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
- mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
- mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
- mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
- mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
+ mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>;
+ mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>;
+ mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>;
+ mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>;
+ mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>;
+ mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>;
+ mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>;
+ mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>;
+ expected_typeB = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ1_TN:
+ assert (ne00 % QK_IQ1BN == 0);
+ mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>;
+ mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>;
+ mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>;
+ mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>;
+ mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>;
+ mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>;
+ mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>;
+ mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
@@ -5825,16 +5842,17 @@ template <int nrc> struct Q8_K64 {
Q8_K64(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
- std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
- y[iy] = (const int8_t *)(dptr + 4);
+ std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
+ y[iy] = (const int8_t *)(dptr + 8);
}
}
inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
- inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); }
+ inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
+ inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
- float d[4*nrc_y];
+ float d[8*nrc_y];
const int8_t * y[nrc_y];
};
@@ -5866,9 +5884,17 @@ struct DequantizerIQ1BN {
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
}
}
+
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
+ auto data = vld1q_u8((const uint8_t *)x);
+ for (int k = 0; k < 4; ++k) {
+ auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+ v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6));
+ }
+ }
};
-template <int nrc_y>
+template <int nrc_y, bool is_iq1_tn>
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
@@ -5878,19 +5904,25 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
int32x4_t accd[nrc_y];
int8x16x4_t v1, v2;
- const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
+ float scale;
for (int ix = 0; ix < nrc_x; ++ix) {
- x = (const block_iq1_bn *)((const char *)vx + ix*bx);
+ const char * cx = ((const char *)vx + ix*bx);
+ if constexpr (is_iq1_tn) {
+ scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx);
+ cx += sizeof(ggml_half);
+ }
+
+ const block_iq1_bn * x = (const block_iq1_bn *)cx;
if constexpr (nrc_y == 1) {
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants(x+2*i+0, v1);
+ deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
- deq.prepare_iq1bn_quants(x+2*i+1, v2);
+ deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
q = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
}
@@ -5902,8 +5934,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int i = 0; i < nb/2; ++i) {
- deq.prepare_iq1bn_quants(x+2*i+0, v1);
- deq.prepare_iq1bn_quants(x+2*i+1, v2);
+ deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
+ deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
@@ -5919,7 +5951,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
int i = 2*(nb/2);
if (i < nb) {
- deq.prepare_iq1bn_quants(x+i, v1);
+ deq.prepare_iq1bn_quants_nosub(x+i, v1);
if constexpr (nrc_y == 1) {
auto q = q8.load_quants(0, i/2, 0);
for (int j = 0; j < 4; ++j) {
@@ -5936,7 +5968,11 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ if constexpr (is_iq1_tn) {
+ info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ } else {
+ info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ }
}
}
@@ -5964,10 +6000,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int j = 0; j < 2; ++j) {
auto q = q8.load_quants64(0, i, j);
auto q2bits = vld1q_u8(x[2*i+j].qs);
- v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
@@ -5980,15 +6016,15 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
for (int i = 0; i < nb/2; ++i) {
auto q2bits = vld1q_u8(x[2*i+0].qs);
- v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
q2bits = vld1q_u8(x[2*i+1].qs);
- v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v2.val[0] = vandq_s8(q2bits, mask2);
+ v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v2.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6005,10 +6041,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
if (i < nb) {
auto q2bits = vld1q_u8(x[i].qs);
int8x16x4_t v1;
- v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
- v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
- v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
- v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ v1.val[0] = vandq_s8(q2bits, mask2);
+ v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+ v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+ v1.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i/2, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6018,7 +6054,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}
@@ -6133,14 +6169,25 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerIQ3S>(m);
break;
case GGML_TYPE_IQ1_BN:
- m.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
- m.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
- m.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
- m.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
- m.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
- m.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
- m.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
- m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
+ m.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>;
+ m.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>;
+ m.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>;
+ m.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>;
+ m.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>;
+ m.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>;
+ m.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>;
+ m.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>;
+ expected_Btype = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ1_TN:
+ m.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>;
+ m.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>;
+ m.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>;
+ m.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>;
+ m.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>;
+ m.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>;
+ m.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>;
+ m.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>;
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 0968becf..9b39a490 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -119,6 +119,54 @@ void quantize_row_iq1_bn(const float * x, void * y, int64_t k) {
quantize_iq1_bn(x, y, 1, k, nullptr);
}
+void quantize_row_iq1_tn_ref(const float * x, block_iq1_tn * y, int64_t k) {
+ quantize_iq1_tn(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq1_tn(const float * x, void * y, int64_t k) {
+ quantize_iq1_tn(x, y, 1, k, nullptr);
+}
+
+size_t quantize_iq1_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(n_per_row >= 2*QK_K); // so we have space for the scale
+ int nblock = n_per_row/QK_IQ1BN;
+ float tmp[QK_IQ1BN];
+ char * qrow = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
+ IQ1BNQuantizer iq1bn;
+ for (int row = 0; row < nrows; ++row) {
+ float max = fabsf(src[0]);
+ for (int j = 1; j < n_per_row; ++j) max = std::max(max, fabsf(src[j]));
+ if (!(max > 0)) printf("%s: found max = %g?\n", __func__, max);
+ //GGML_ASSERT(max > 0);
+ *(ggml_half *)qrow = GGML_FP32_TO_FP16(max);
+ block_iq1_bn * y = (block_iq1_bn *)(qrow + sizeof(ggml_half));
+ const float * xb = src;
+ for (int ib = 0; ib < nblock; ++ib) {
+ for (int j = 0; j < QK_IQ1BN; ++j) tmp[j] = xb[j] < -0.5f*max ? -1 : xb[j] <= 0.5f*max ? 0 : 1;
+ iq1bn.quantize_one_row_1bn(tmp, y, QK_IQ1BN, imatrix);
+ ++y;
+ xb += QK_IQ1BN;
+ }
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_iq1_tn(const block_iq1_tn * x, float * y, int64_t k) {
+ float scale = GGML_FP16_TO_FP32(*(const ggml_half *)x);
+ const block_iq1_bn * iq1bn = (const block_iq1_bn *)((const char *)x + sizeof(ggml_half));
+ dequantize_row_iq1_bn(iq1bn, y, k);
+ for (int j = 0; j < int(k); ++j) y[j] *= scale;
+}
+
+void vec_dot_iq1_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ float scale = GGML_FP16_TO_FP32(*(const ggml_half *)vx);
+ ggml_vec_dot_iq1_bn_q8_K64(n, s, bs, (const void *)((const char *)vx + sizeof(ggml_half)), bx, vy, by, nrc);
+ *s *= scale;
+}
+
void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
assert(k%QK_IQ1BN == 0);
int nblock = k / QK_IQ1BN;
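A minimal round-trip sketch of the new entry points (declared in iqk_quantize.h later in this diff), assuming the ggml and iqk headers are included; iq1_tn_round_trip is a hypothetical helper, not part of the commit:

#include <vector>

// Quantize one row of floats to IQ1_TN and dequantize it back.
// n_per_row must be a multiple of QK_K and at least 2*QK_K
// (see the GGML_ASSERT in quantize_iq1_tn above).
static std::vector<float> iq1_tn_round_trip(const std::vector<float> & src) {
    const int64_t n_per_row = (int64_t)src.size();
    std::vector<char>  packed(ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row));
    std::vector<float> out(n_per_row);
    quantize_iq1_tn(src.data(), packed.data(), /*nrows=*/1, n_per_row, /*imatrix=*/nullptr);
    dequantize_row_iq1_tn((const block_iq1_tn *)packed.data(), out.data(), n_per_row);
    return out;
}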
@@ -331,8 +379,10 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
+ GGML_ASSERT(k >= 8*QK_IQ1BN);
+
float * dptr = (float *)y;
- auto qs = (int8_t *)(dptr + 4);
+ auto qs = (int8_t *)(dptr + 8);
#ifdef __ARM_NEON
static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
auto shuffle = vld1q_u8(k_shuffle);
@@ -351,16 +401,22 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
vid[i] = vdupq_n_f32(id);
}
int8x16x4_t q;
+ int32x4_t qsum = {};
+ const int8x16_t m1 = vdupq_n_s8(1);
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
auto val = vld1q_f32(x + j + 4*i);
val = vmulq_f32(vid[i], val);
- q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val));
+ auto ival = vcvtnq_s32_f32(val);
+ q.val[i] = vreinterpretq_s8_s32(ival);
}
auto qi = vqtbl4q_s8(q, shuffle);
+ qsum = ggml_vdotq_s32(qsum, qi, m1);
vst1q_s8(qs, qi);
qs += 16;
}
+ auto sumf = vmulq_f32(vld1q_f32(dptr), vcvtq_f32_s32(qsum));
+ vst1q_f32(dptr + 4, sumf);
#elif defined __AVX__
__m128 max[4] = {};
__m128 sign_bit = _mm_set1_ps(-0.f);
@@ -381,6 +437,9 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
vid[i] = _mm_set1_ps(id);
}
__m128i q[4];
+ __m128i sums = _mm_setzero_si128();
+ __m128i m1_8 = _mm_set1_epi8(1);
+ __m128i m1_16 = _mm_set1_epi16(1);
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
auto val = _mm_loadu_ps(x + j + 4*i);
@@ -390,9 +449,13 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
auto q1 = _mm_packs_epi32(q[0], q[1]);
auto q2 = _mm_packs_epi32(q[2], q[3]);
auto qi = _mm_packs_epi16(q1, q2);
+ auto aux = _mm_maddubs_epi16(m1_8, qi);
+ sums = _mm_add_epi32(sums, _mm_madd_epi16(m1_16, aux));
_mm_storeu_si128((__m128i *)qs, qi);
qs += 16;
}
+ auto minus = _mm_mul_ps(_mm_loadu_ps(dptr), _mm_cvtepi32_ps(sums));
+ _mm_storeu_ps(dptr + 4, minus);
#else
float aux[4] = {0.f, 0.f, 0.f, 0.f};
for (int j = 0; j < k; j += 16) {
@@ -407,11 +470,16 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
dptr[i] = aux[i]/127;
aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
}
+ int32_t sum[4] = {};
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
- for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
+ for (int l = 0; l < 4; ++l) {
+ qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
+ sum[i] += qs[j+4*i+l];
+ }
}
}
+ for (int i = 0; i < 4; ++i) dptr[4+i] = dptr[i]*sum[i];
#endif
}
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index d7a0748f..e5c16fc9 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -49,6 +49,12 @@ size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst
void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq1_tn_ref(const float * GGML_RESTRICT x, block_iq1_tn * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq1_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq1_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq1_tn(const block_iq1_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq1_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
#ifdef __cplusplus
diff --git a/include/llama.h b/include/llama.h
index a9af4c48..02d94b6c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -166,14 +166,16 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ1_BN = 36, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ2_BN = 37, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ2_K = 38, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ3_K = 39, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ4_K = 40, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ5_K = 41, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ6_K = 42, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_IQ2_TN = 43, // except 1d tensors
+ //
+ LLAMA_FTYPE_MOSTLY_IQ1_BN = 136, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ2_BN = 137, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ2_K = 138, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ3_K = 139, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ4_K = 140, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ5_K = 141, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ6_K = 142, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ2_TN = 143, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ1_TN = 144, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
diff --git a/src/llama.cpp b/src/llama.cpp
index 768aafa7..bb9b6848 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3788,6 +3788,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break;
case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break;
case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break;
+ case GGML_TYPE_IQ1_TN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_TN; break;
case GGML_TYPE_IQ2_TN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_TN; break;
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
@@ -4497,8 +4498,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet";
+ case LLAMA_FTYPE_MOSTLY_IQ1_TN: return "IQ1_TN - 1.6875 bpw TriLM";
case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet";
- case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQT_BN - 2.06 bpw TriLM";
+ case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.06 bpw TriLM";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
@@ -15644,7 +15646,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) {
new_type = GGML_TYPE_IQ4_NL;
}
- else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_TN) {
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_TN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_TN) {
new_type = GGML_TYPE_Q4_K;
}
else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
@@ -15856,7 +15858,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K ||
new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN ||
- new_type == GGML_TYPE_IQ6_K) {
+ new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -15881,6 +15883,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_TN:
case GGML_TYPE_IQ2_TN:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
@@ -15991,6 +15994,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break;
case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break;
+ case LLAMA_FTYPE_MOSTLY_IQ1_TN: default_type = GGML_TYPE_IQ1_TN; break;
case LLAMA_FTYPE_MOSTLY_IQ2_TN: default_type = GGML_TYPE_IQ2_TN; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;