author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>   2024-09-09 14:56:34 +0300
committer  GitHub <noreply@github.com>                              2024-09-09 14:56:34 +0300
commit     8c86231f9306c81dc291c4c4a16f88bbc7c97793 (patch)
tree       d49325de2775076e1f71ddf94667d0cd02db3cc5
parent     bf4b19b474b78a6ddfa1f0fe19f76f3c7ac92030 (diff)
Adding IQ1_TN - 1.6875 bpw for TriLM ternary models (#44)
* Adding iq1_tn - 1.6875 bpw for TriLM ternary models
* iq1_tn: NEON
* iq1_tn: faster NEON
* iq2_bn: improve performance on NEON
We now get TG-128 = 100 t/s for Bitnet-3B-1.58b!
* iq1_tn: improve AVX2
PP-512 goes to 533 t/s, up from 455.
TG-128 @ 2 threads goes to 16.6 t/s, up from 14.2.
However, we seem to have a bottleneck somewhere, as
TG saturates at 8 threads.
* iq1_tn: improve Zen4
PP-512 goes to 485 t/s, up from 352. With FA we get 545 t/s, up from 380.
TG-128 @ 1 thread goes to 12.4 t/s, up from 10.4.
However, we seem to have a bottleneck somewhere, as
TG saturates at 8 threads.
* iq2_bn: improve on Zen4
We now get PP-512 = 614 t/s, up from 542 t/s.
* iq2_bn: improve AVX2 implementation
We now get PP-512 = 753 t/s, up from 680 t/s.
* Remove unnecessary barrier in ggml_compute_forward_mul_mat
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
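
For reference, the 1.6875 bpw figure in the title follows directly from the block layout introduced in this commit: block_iq1_tn packs QK_K = 256 ternary weights into 54 bytes (see the ggml-common.h hunk below), and the extra 0.0625 bpw relative to IQ1_BN's 1.625 bpw corresponds to 2 bytes per 256 weights, which is where the fp16 scale written by quantize_iq1_tn is accounted for. A minimal sketch of the arithmetic (the constants come from the diff; the program itself is only illustrative):

    #include <cstdio>

    int main() {
        // From the diff: sizeof(block_iq1_tn) == 54 (ggml-common.h) and the
        // type is registered with .blck_size = QK_K = 256 (ggml.c).
        constexpr double bytes_per_block   = 54.0;
        constexpr double weights_per_block = 256.0;
        printf("IQ1_TN: %.4f bpw\n", 8.0 * bytes_per_block / weights_per_block); // 1.6875
        return 0;
    }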
-rw-r--r--   examples/quantize/quantize.cpp   |   1
-rw-r--r--   ggml/include/ggml.h              |  38
-rw-r--r--   ggml/src/ggml-common.h           |   4
-rw-r--r--   ggml/src/ggml-quants.c           |   1
-rw-r--r--   ggml/src/ggml.c                  |  35
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp     | 265
-rw-r--r--   ggml/src/iqk/iqk_quantize.cpp    |  74
-rw-r--r--   ggml/src/iqk/iqk_quantize.h      |   6
-rw-r--r--   include/llama.h                  |  18
-rw-r--r--   src/llama.cpp                    |  10
10 files changed, 304 insertions, 148 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 9a08d625..c6153e45 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -28,6 +28,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ1_M",   LLAMA_FTYPE_MOSTLY_IQ1_M,   " 1.75 bpw quantization",            },
     { "IQ1_BN",  LLAMA_FTYPE_MOSTLY_IQ1_BN,  " 1.62 bpw quantization (Bitnet)",   },
     { "IQ2_BN",  LLAMA_FTYPE_MOSTLY_IQ2_BN,  " 2.00 bpw quantization (Bitnet)",   },
+    { "IQ1_TN",  LLAMA_FTYPE_MOSTLY_IQ1_TN,  " 1.69 bpw quantization (TriLM)",    },
     { "IQ2_TN",  LLAMA_FTYPE_MOSTLY_IQ2_TN,  " 2.06 bpw quantization (TriLM)",    },
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index ab6d172d..5b46a70d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -391,15 +391,17 @@ extern "C" {
         GGML_TYPE_Q4_0_4_4 = 31,
         GGML_TYPE_Q4_0_4_8 = 32,
         GGML_TYPE_Q4_0_8_8 = 33,
-        GGML_TYPE_IQ1_BN   = 34,
-        GGML_TYPE_IQ2_BN   = 35,
-        GGML_TYPE_Q8_K64   = 36,
-        GGML_TYPE_IQ2_K    = 37,
-        GGML_TYPE_IQ3_K    = 38,
-        GGML_TYPE_IQ4_K    = 39,
-        GGML_TYPE_IQ5_K    = 40,
-        GGML_TYPE_IQ6_K    = 41,
-        GGML_TYPE_IQ2_TN   = 42,
+        //
+        GGML_TYPE_IQ1_BN   = 134,
+        GGML_TYPE_IQ2_BN   = 135,
+        GGML_TYPE_Q8_K64   = 136,
+        GGML_TYPE_IQ2_K    = 137,
+        GGML_TYPE_IQ3_K    = 138,
+        GGML_TYPE_IQ4_K    = 139,
+        GGML_TYPE_IQ5_K    = 140,
+        GGML_TYPE_IQ6_K    = 141,
+        GGML_TYPE_IQ2_TN   = 142,
+        GGML_TYPE_IQ1_TN   = 143,
         GGML_TYPE_COUNT,
     };
@@ -444,14 +446,16 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ1_BN   = 28, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ2_BN   = 29, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ2_K    = 30, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ3_K    = 31, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ4_K    = 32, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ5_K    = 33, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ6_K    = 34, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ2_TN   = 35, // except 1d tensors
+        //
+        GGML_FTYPE_MOSTLY_IQ1_BN   = 128, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_BN   = 129, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_K    = 130, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_K    = 131, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_K    = 132, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ5_K    = 133, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ6_K    = 134, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_TN   = 135, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_TN   = 136, // except 1d tensors
     };
 
     // available tensor operations:
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 57fdeb82..f1a34f7a 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -436,6 +436,10 @@ static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
 // TriLM - implemented as 2.0625 bpw
 //
 typedef struct {
+    uint8_t qs[54];
+} block_iq1_tn;
+static_assert(sizeof(block_iq1_tn) == 54, "wrong iq1_tn block size/padding");
+typedef struct {
     ggml_half d;
     uint8_t qs[QK_K/4];
 } block_iq2_tn;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 981fb54b..a9a25761 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15015,6 +15015,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
         case GGML_TYPE_IQ5_K: break;
         case GGML_TYPE_IQ6_K: break;
         case GGML_TYPE_IQ2_TN: break;
+        case GGML_TYPE_IQ1_TN: break;
         case GGML_TYPE_Q4_0_4_4:
         case GGML_TYPE_Q4_0_4_8:
             {
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d562002e..4fdf9c18 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -985,6 +985,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_IQ1_TN] = {
+        .type_name                = "iq1_tn",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq1_tn),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_tn,
+        .from_float               = quantize_row_iq1_tn,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_iq1_tn_ref,
+        .vec_dot                  = vec_dot_iq1_tn_q8_k,
+        .vec_dot_type             = GGML_TYPE_Q8_K64,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_IQ4_NL] = {
         .type_name                = "iq4_nl",
         .blck_size                = QK4_NL,
@@ -3705,6 +3717,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ1_BN:  wtype = GGML_TYPE_IQ1_BN;  break;
         case GGML_FTYPE_MOSTLY_IQ2_BN:  wtype = GGML_TYPE_IQ2_BN;  break;
         case GGML_FTYPE_MOSTLY_IQ2_TN:  wtype = GGML_TYPE_IQ2_TN;  break;
+        case GGML_FTYPE_MOSTLY_IQ1_TN:  wtype = GGML_TYPE_IQ1_TN;  break;
         case GGML_FTYPE_MOSTLY_IQ4_NL:  wtype = GGML_TYPE_IQ4_NL;  break;
         case GGML_FTYPE_MOSTLY_IQ4_XS:  wtype = GGML_TYPE_IQ4_XS;  break;
         case GGML_FTYPE_MOSTLY_IQ2_K:   wtype = GGML_TYPE_IQ2_K;   break;
@@ -10133,6 +10146,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_TN:
+        case GGML_TYPE_IQ1_TN:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
@@ -10519,6 +10533,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_TN:
+        case GGML_TYPE_IQ1_TN:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
@@ -10655,6 +10670,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_TN:
+        case GGML_TYPE_IQ1_TN:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
@@ -13078,14 +13094,14 @@ UseGgmlGemm1:;
         int64_t t2 = ggml_time_us();
         if (ith == 0) printf("quantize(%s): %d us\n", dst->name, (int)(t2 - t1));
 #endif
-    }
-    if (ith == 0) {
-        // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
-        atomic_store(&params->shared->current_chunk, nth);
-    }
+        if (ith == 0) {
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+            atomic_store(&params->shared->current_chunk, nth);
+        }
 
-    ggml_barrier(params->shared);
+        ggml_barrier(params->shared);
+    }
 
     const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
@@ -13104,8 +13120,6 @@ UseGgmlGemm1:;
 IQK_MulMat_Not_Available2:;
 #endif
 
-    ggml_barrier(params->shared);
-
 #if GGML_USE_LLAMAFILE
     if (src1->type != vec_dot_type) {
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -13692,6 +13706,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_TN:
+        case GGML_TYPE_IQ1_TN:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
@@ -14068,6 +14083,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_TN:
+        case GGML_TYPE_IQ1_TN:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
@@ -14338,6 +14354,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_TN:
+        case GGML_TYPE_IQ1_TN:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
@@ -14935,6 +14952,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_TN:
+        case GGML_TYPE_IQ1_TN:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
@@ -21722,6 +21740,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ1_BN:  result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_BN:  result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_TN:  result = quantize_iq2_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ1_TN:  result = quantize_iq1_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_K:   result = quantize_iq2_k  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
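The two UseGgmlGemm1 hunks above move the current_chunk initialization and its barrier into the branch that actually quantizes src1, and drop the second, now-redundant ggml_barrier. For readers unfamiliar with the pattern the moved comment refers to, here is a self-contained sketch of the atomic chunk-claiming scheme (illustrative names only, not ggml's API): every thread starts on the chunk equal to its own index, so the counter begins at nth.

    #include <atomic>
    #include <thread>
    #include <vector>

    static std::atomic<int> current_chunk;

    static void worker(int ith, int nchunk, std::vector<int>& owner) {
        for (int chunk = ith; chunk < nchunk; ) {
            owner[chunk] = ith;                 // "process" this chunk
            chunk = current_chunk.fetch_add(1); // claim the next unprocessed one
        }
    }

    int main() {
        const int nth = 4, nchunk = 64;
        std::vector<int> owner(nchunk, -1);
        current_chunk.store(nth); // first unprocessed chunk is nth
        std::vector<std::thread> pool;
        for (int ith = 0; ith < nth; ++ith)
            pool.emplace_back(worker, ith, nchunk, std::ref(owner));
        for (auto& t : pool) t.join();
        return 0;
    }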
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 55366ab1..b5e3cba3 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2078,15 +2078,16 @@ template <int nrc> struct Q8_K64 {
 
     Q8_K64(const DataInfo& info) {
         for (int iy = 0; iy < nrc_y; ++iy) {
             const float * dptr = (const float *)info.src1_row(iy);
-            std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
-            y[iy] = (const int8_t *)(dptr + 4);
+            std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
+            y[iy] = (const int8_t *)(dptr + 8);
         }
     }
 
     inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
-    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); }
+    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
+    inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }
 
-    float d[4*nrc_y];
+    float d[8*nrc_y];
     const int8_t * y[nrc_y];
 };
@@ -2121,17 +2122,17 @@ struct DequantizerIQ1BN {
         auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
         auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
 #ifdef HAVE_FANCY_SIMD
-        v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
-        v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
+        v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
+        v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
 #else
-        v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
-        v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
+        v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
+        v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
 #endif
     }
 };
 
-template <int nrc_y>
+template <int nrc_y, bool is_iq1_tn>
 IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_IQ1BN;
     Q8_K64<nrc_y> q8(info);
@@ -2143,11 +2144,18 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const auto m1_16 = _mm256_set1_epi16(1);
 #endif
 
-    const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
+    const block_iq1_bn * x;
+    const char * cx0 = (const char *)vx;
+    float scale;
 
     for (int ix = 0; ix < nrc_x; ++ix) {
 
-        x = (const block_iq1_bn *)((const char *)vx + ix*bx);
+        const char * cx = cx0 + ix*bx;
+        if constexpr (is_iq1_tn) {
+            scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx);
+            cx += sizeof(ggml_half);
+        }
+        x = (const block_iq1_bn *)cx;
 
         if constexpr (nrc_y == 1) {
             __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
@@ -2155,17 +2163,13 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
                 deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                 deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
-                auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]);
-                auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]);
-                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
-                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
+                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
+                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
 #else
-                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
-                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
+                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
+                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
+                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                 acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
                 acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
 #endif
@@ -2183,17 +2187,16 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
                 for (int iy = 0; iy < nrc_y; ++iy) {
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                    auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
-                    auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
-                    auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
-                    auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
-                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                                        accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+                                        val[0], q8.load_quants(iy, i, 0)),
+                                        val[1], q8.load_quants(iy, i, 1)),
+                                        val[2], q8.load_quants(iy, i, 2)),
+                                        val[3], q8.load_quants(iy, i, 3));
 #else
-                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
-                                                 _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
-                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
-                                                 _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
+                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
+                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
+                                                 _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
                     dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
                     accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
@@ -2204,13 +2207,12 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         if (i < nb) {
             deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
             for (int iy = 0; iy < nrc_y; ++iy) {
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
+                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+                                val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
 #else
-                auto dot = _mm256_madd_epi16(m1_16,
-                        _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
+                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
+                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
                 accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
             }
@@ -2219,8 +2221,12 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto vd = q8.scale(iy);
             auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
-            auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
-            info.store(ix, iy, hsum_float_4(sumf));
+            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
+            if constexpr (is_iq1_tn) {
+                info.store(ix, iy, scale*hsum_float_4(sumf));
+            } else {
+                info.store(ix, iy, hsum_float_4(sumf));
+            }
         }
     }
@@ -2236,8 +2242,8 @@ struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
         make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
     }
     IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
-        val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
-        val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
+        val[0] = _mm256_and_si256(q2_1, mask2);
+        val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
     }
     IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
         auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
@@ -2270,15 +2276,15 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
             for (int i = 0; i < nb/2; ++i) {
                 deq.prepare4(i, val);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
-                                             deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]));
-                acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
-                                             deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]));
+                acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
+                                             val[1], q8.load_quants(0, i, 1));
+                acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
+                                             val[3], q8.load_quants(0, i, 3));
 #else
-                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
-                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
+                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
+                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
+                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                 acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
                 acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
 #endif
@@ -2292,18 +2298,17 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
             for (int i = 0; i < nb/2; ++i) {
                 deq.prepare4(i, val);
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
-                    auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
-                    auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
-                    auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                                        accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
+                                        val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
+                                        val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
 #else
                     auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
-                                _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
-                                _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
-                    accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot;
+                                _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
+                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
+                                _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
+                                                 _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
+                    accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
                 }
             }
@@ -2312,13 +2317,13 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         if (i < nb) {
             deq.prepare2(i, val);
             for (int iy = 0; iy < nrc_y; ++iy) {
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
+                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
+                                               val[1], q8.load_quants(iy, i/2, 1));
 #else
-                dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
-                accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
+                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
+                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 0))));
+                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
             }
         }
@@ -2326,7 +2331,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto vd = q8.scale(iy);
             auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
-            auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
             info.store(ix, iy, hsum_float_4(sumf));
         }
     }
@@ -3733,14 +3738,26 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             break;
         case GGML_TYPE_IQ1_BN:
            assert (ne00 % QK_IQ1BN == 0);
-            mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
-            mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
-            mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
-            mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
-            mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
-            mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
-            mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
-            mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
+            mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>;
+            mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>;
+            mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>;
+            mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>;
+            mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>;
+            mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>;
+            mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>;
+            mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>;
+            expected_typeB = GGML_TYPE_Q8_K64;
+            break;
+        case GGML_TYPE_IQ1_TN:
+            assert (ne00 % QK_IQ1BN == 0);
+            mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>;
+            mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>;
+            mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>;
+            mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>;
+            mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>;
+            mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>;
+            mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>;
+            mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>;
             expected_typeB = GGML_TYPE_Q8_K64;
             break;
         case GGML_TYPE_IQ2_BN:
@@ -5825,16 +5842,17 @@ template <int nrc> struct Q8_K64 {
 
     Q8_K64(const DataInfo& info) {
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto dptr = (const float *)info.src1_row(iy);
-            std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
-            y[iy] = (const int8_t *)(dptr + 4);
+            std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
+            y[iy] = (const int8_t *)(dptr + 8);
         }
     }
 
     inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
     inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
-    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); }
+    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
+    inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
 
-    float d[4*nrc_y];
+    float d[8*nrc_y];
     const int8_t * y[nrc_y];
 };
@@ -5866,9 +5884,17 @@ struct DequantizerIQ1BN {
             v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
         }
     }
+
+    IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
+        auto data = vld1q_u8((const uint8_t *)x);
+        for (int k = 0; k < 4; ++k) {
+            auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+            v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6));
+        }
+    }
 };
 
-template <int nrc_y>
+template <int nrc_y, bool is_iq1_tn>
 static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
     const int nb = n / QK_IQ1BN;
@@ -5878,19 +5904,25 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     int32x4_t accd[nrc_y];
     int8x16x4_t v1, v2;
 
-    const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
+    float scale;
 
     for (int ix = 0; ix < nrc_x; ++ix) {
 
-        x = (const block_iq1_bn *)((const char *)vx + ix*bx);
+        const char * cx = ((const char *)vx + ix*bx);
+        if constexpr (is_iq1_tn) {
+            scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx);
+            cx += sizeof(ggml_half);
+        }
+
+        const block_iq1_bn * x = (const block_iq1_bn *)cx;
 
         if constexpr (nrc_y == 1) {
             int32x4_t acc[4] = {};
             for (int i = 0; i < nb/2; ++i) {
-                deq.prepare_iq1bn_quants(x+2*i+0, v1);
+                deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
                 auto q = q8.load_quants64(0, i, 0);
                 for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
-                deq.prepare_iq1bn_quants(x+2*i+1, v2);
+                deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
                 q = q8.load_quants64(0, i, 1);
                 for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
             }
@@ -5902,8 +5934,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
             for (int i = 0; i < nb/2; ++i) {
 
-                deq.prepare_iq1bn_quants(x+2*i+0, v1);
-                deq.prepare_iq1bn_quants(x+2*i+1, v2);
+                deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
+                deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
 
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto q = q8.load_quants(iy, i, 0);
@@ -5919,7 +5951,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         }
         int i = 2*(nb/2);
         if (i < nb) {
-            deq.prepare_iq1bn_quants(x+i, v1);
+            deq.prepare_iq1bn_quants_nosub(x+i, v1);
             if constexpr (nrc_y == 1) {
                 auto q = q8.load_quants(0, i/2, 0);
                 for (int j = 0; j < 4; ++j) {
@@ -5936,7 +5968,11 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         }
 
         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+            if constexpr (is_iq1_tn) {
+                info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+            } else {
+                info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+            }
         }
     }
@@ -5964,10 +6000,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
                 for (int j = 0; j < 2; ++j) {
                     auto q = q8.load_quants64(0, i, j);
                    auto q2bits = vld1q_u8(x[2*i+j].qs);
-                    v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-                    v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-                    v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-                    v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+                    v1.val[0] = vandq_s8(q2bits, mask2);
+                    v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+                    v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+                    v1.val[3] = vshrq_n_u8(q2bits, 6);
                     acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
                     acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
                     acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
@@ -5980,15 +6016,15 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
             for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
             for (int i = 0; i < nb/2; ++i) {
                 auto q2bits = vld1q_u8(x[2*i+0].qs);
-                v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-                v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-                v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-                v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+                v1.val[0] = vandq_s8(q2bits, mask2);
+                v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+                v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+                v1.val[3] = vshrq_n_u8(q2bits, 6);
                 q2bits = vld1q_u8(x[2*i+1].qs);
-                v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-                v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-                v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-                v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+                v2.val[0] = vandq_s8(q2bits, mask2);
+                v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+                v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+                v2.val[3] = vshrq_n_u8(q2bits, 6);
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto q = q8.load_quants(iy, i, 0);
                     accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6005,10 +6041,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
        if (i < nb) {
             auto q2bits = vld1q_u8(x[i].qs);
             int8x16x4_t v1;
-            v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-            v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-            v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-            v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+            v1.val[0] = vandq_s8(q2bits, mask2);
+            v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
+            v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
+            v1.val[3] = vshrq_n_u8(q2bits, 6);
             for (int iy = 0; iy < nrc_y; ++iy) {
                 auto q = q8.load_quants(iy, i/2, 0);
                 accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
@@ -6018,7 +6054,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+            info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
         }
     }
 }
@@ -6133,14 +6169,25 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
             MulMat::set_functions<DequantizerIQ3S>(m);
             break;
         case GGML_TYPE_IQ1_BN:
-            m.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
-            m.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
-            m.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
-            m.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
-            m.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
-            m.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
-            m.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
-            m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
+            m.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>;
+            m.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>;
+            m.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>;
+            m.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>;
+            m.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>;
+            m.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>;
+            m.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>;
+            m.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>;
+            expected_Btype = GGML_TYPE_Q8_K64;
+            break;
+        case GGML_TYPE_IQ1_TN:
+            m.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>;
+            m.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>;
+            m.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>;
+            m.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>;
+            m.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>;
+            m.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>;
+            m.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>;
+            m.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>;
             expected_Btype = GGML_TYPE_Q8_K64;
             break;
         case GGML_TYPE_IQ2_BN:
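The common thread in the AVX2/AVX512 and NEON hunks above: the ternary values are now kept unsigned, as {0, 1, 2}, instead of being shifted to {-1, 0, 1} during dequantization. That removes the per-value _mm256_sign_epi8/vsubq_s8 fixups and lets the kernels feed _mm256_dpbusd_epi32 / _mm256_maddubs_epi16 (which want an unsigned operand) and ggml_vdotq_s32 directly. The shift by 1 is folded into a per-row correction instead: Q8_K64 now carries four extra floats, minus(iy) = d * sum(y), precomputed when the activations are quantized (see quantize_row_q8_K64_ref below), and the kernels finish with a fused multiply-subtract. A scalar sketch of the identity being exploited (illustrative, not the actual kernel):

    #include <cstdio>

    int main() {
        // w[i] in {-1, 0, 1} stored as u[i] = w[i] + 1 in {0, 1, 2};
        // y[i] are the int8 activations.
        const int u[8] = {0, 2, 1, 2, 0, 1, 2, 0};
        const int y[8] = {5, -3, 7, 1, -2, 4, -6, 8};
        int want = 0, cheap = 0, sum_y = 0;
        for (int i = 0; i < 8; ++i) {
            want  += (u[i] - 1) * y[i]; // the true ternary dot product
            cheap += u[i] * y[i];       // what dpbusd/maddubs/vdotq compute
            sum_y += y[i];              // known when y was quantized
        }
        // sum (u-1)*y == sum u*y - sum y, hence sumf = fmsub(d, cheap, d*sum_y).
        printf("%d == %d - %d\n", want, cheap, sum_y);
        return 0;
    }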
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 0968becf..9b39a490 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -119,6 +119,54 @@ void quantize_row_iq1_bn(const float * x, void * y, int64_t k) {
     quantize_iq1_bn(x, y, 1, k, nullptr);
 }
 
+void quantize_row_iq1_tn_ref(const float * x, block_iq1_tn * y, int64_t k) {
+    quantize_iq1_tn(x, (void *)y, 1, k, nullptr);
+}
+
+void quantize_row_iq1_tn(const float * x, void * y, int64_t k) {
+    quantize_iq1_tn(x, y, 1, k, nullptr);
+}
+
+size_t quantize_iq1_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+    GGML_ASSERT(n_per_row >= 2*QK_K); // so we have space for the scale
+    int nblock = n_per_row/QK_IQ1BN;
+    float tmp[QK_IQ1BN];
+    char * qrow = (char *)dst;
+    auto row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
+    IQ1BNQuantizer iq1bn;
+    for (int row = 0; row < nrows; ++row) {
+        float max = fabsf(src[0]);
+        for (int j = 1; j < n_per_row; ++j) max = std::max(max, fabsf(src[j]));
+        if (!(max > 0)) printf("%s: found max = %g?\n", __func__, max);
+        //GGML_ASSERT(max > 0);
+        *(ggml_half *)qrow = GGML_FP32_TO_FP16(max);
+        block_iq1_bn * y = (block_iq1_bn *)(qrow + sizeof(ggml_half));
+        const float * xb = src;
+        for (int ib = 0; ib < nblock; ++ib) {
+            for (int j = 0; j < QK_IQ1BN; ++j) tmp[j] = xb[j] < -0.5f*max ? -1 : xb[j] <= 0.5f*max ? 0 : 1;
+            iq1bn.quantize_one_row_1bn(tmp, y, QK_IQ1BN, imatrix);
+            ++y;
+            xb += QK_IQ1BN;
+        }
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrows*row_size;
+}
+
+void dequantize_row_iq1_tn(const block_iq1_tn * x, float * y, int64_t k) {
+    float scale = GGML_FP16_TO_FP32(*(const ggml_half *)x);
+    const block_iq1_bn * iq1bn = (const block_iq1_bn *)((const char *)x + sizeof(ggml_half));
+    dequantize_row_iq1_bn(iq1bn, y, k);
+    for (int j = 0; j < int(k); ++j) y[j] *= scale;
+}
+
+void vec_dot_iq1_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+    float scale = GGML_FP16_TO_FP32(*(const ggml_half *)vx);
+    ggml_vec_dot_iq1_bn_q8_K64(n, s, bs, (const void *)((const char *)vx + sizeof(ggml_half)), bx, vy, by, nrc);
+    *s *= scale;
+}
+
 void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
     assert(k%QK_IQ1BN == 0);
     int nblock = k / QK_IQ1BN;
@@ -331,8 +379,10 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
 
 void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
 
+    GGML_ASSERT(k >= 8*QK_IQ1BN);
+
     float * dptr = (float *)y;
-    auto qs = (int8_t *)(dptr + 4);
+    auto qs = (int8_t *)(dptr + 8);
 #ifdef __ARM_NEON
     static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
     auto shuffle = vld1q_u8(k_shuffle);
@@ -351,16 +401,22 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
         vid[i] = vdupq_n_f32(id);
     }
     int8x16x4_t q;
+    int32x4_t qsum = {};
+    const int8x16_t m1 = vdupq_n_s8(1);
     for (int j = 0; j < k; j += 16) {
         for (int i = 0; i < 4; ++i) {
             auto val = vld1q_f32(x + j + 4*i);
             val = vmulq_f32(vid[i], val);
-            q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val));
+            auto ival = vcvtnq_s32_f32(val);
+            q.val[i] = vreinterpretq_s8_s32(ival);
         }
         auto qi = vqtbl4q_s8(q, shuffle);
+        qsum = ggml_vdotq_s32(qsum, qi, m1);
         vst1q_s8(qs, qi);
         qs += 16;
     }
+    auto sumf = vmulq_f32(vld1q_f32(dptr), vcvtq_f32_s32(qsum));
+    vst1q_f32(dptr + 4, sumf);
 #elif defined __AVX__
     __m128 max[4] = {};
     __m128 sign_bit = _mm_set1_ps(-0.f);
@@ -381,6 +437,9 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
         vid[i] = _mm_set1_ps(id);
     }
     __m128i q[4];
+    __m128i sums = _mm_setzero_si128();
+    __m128i m1_8 = _mm_set1_epi8(1);
+    __m128i m1_16 = _mm_set1_epi16(1);
     for (int j = 0; j < k; j += 16) {
         for (int i = 0; i < 4; ++i) {
             auto val = _mm_loadu_ps(x + j + 4*i);
@@ -390,9 +449,13 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
         auto q1 = _mm_packs_epi32(q[0], q[1]);
         auto q2 = _mm_packs_epi32(q[2], q[3]);
         auto qi = _mm_packs_epi16(q1, q2);
+        auto aux = _mm_maddubs_epi16(m1_8, qi);
+        sums = _mm_add_epi32(sums, _mm_madd_epi16(m1_16, aux));
         _mm_storeu_si128((__m128i *)qs, qi);
         qs += 16;
     }
+    auto minus = _mm_mul_ps(_mm_loadu_ps(dptr), _mm_cvtepi32_ps(sums));
+    _mm_storeu_ps(dptr + 4, minus);
 #else
     float aux[4] = {0.f, 0.f, 0.f, 0.f};
     for (int j = 0; j < k; j += 16) {
@@ -407,11 +470,16 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
         dptr[i] = aux[i]/127;
         aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
     }
+    int32_t sum[4] = {};
     for (int j = 0; j < k; j += 16) {
         for (int i = 0; i < 4; ++i) {
-            for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
+            for (int l = 0; l < 4; ++l) {
+                qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
+                sum[i] += qs[j+4*i+l];
+            }
         }
     }
+    for (int i = 0; i < 4; ++i) dptr[4+i] = dptr[i]*sum[i];
 #endif
 }
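Worth spelling out from quantize_iq1_tn above: an IQ1_TN row is just an fp16 copy of the row's absolute maximum followed by plain block_iq1_bn blocks, with each weight snapped to {-1, 0, 1} by thresholding at ±0.5·max. A scalar sketch of that rounding rule (ternarize is a hypothetical helper mirroring the expression in the diff):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    static int ternarize(float x, float max) {
        // Same split as the diff: below -max/2 -> -1, within [-max/2, max/2] -> 0, else +1.
        return x < -0.5f * max ? -1 : x <= 0.5f * max ? 0 : 1;
    }

    int main() {
        const float row[6] = {0.9f, -0.04f, -0.7f, 0.3f, 0.5f, -0.46f};
        float max = 0.0f;
        for (float v : row) max = std::max(max, std::fabs(v)); // stored as the fp16 row scale
        for (float v : row) printf("% .2f -> %d\n", v, ternarize(v, max));
        return 0;
    }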
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index d7a0748f..e5c16fc9 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -49,6 +49,12 @@ size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void quantize_row_iq1_tn_ref(const float * GGML_RESTRICT x, block_iq1_tn * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq1_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq1_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq1_tn(const block_iq1_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq1_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 
 #ifdef __cplusplus
diff --git a/include/llama.h b/include/llama.h
index a9af4c48..02d94b6c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -166,14 +166,16 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ1_BN   = 36, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ2_BN   = 37, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ2_K    = 38, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ3_K    = 39, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ4_K    = 40, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ5_K    = 41, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ6_K    = 42, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_IQ2_TN   = 43, // except 1d tensors
+        //
+        LLAMA_FTYPE_MOSTLY_IQ1_BN   = 136, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ2_BN   = 137, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ2_K    = 138, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ3_K    = 139, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ4_K    = 140, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ5_K    = 141, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ6_K    = 142, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ2_TN   = 143, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ1_TN   = 144, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
diff --git a/src/llama.cpp b/src/llama.cpp
index 768aafa7..bb9b6848 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3788,6 +3788,7 @@ struct llama_model_loader {
             case GGML_TYPE_IQ1_M:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_M;   break;
             case GGML_TYPE_IQ1_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN;  break;
             case GGML_TYPE_IQ2_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN;  break;
+            case GGML_TYPE_IQ1_TN:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_TN;  break;
             case GGML_TYPE_IQ2_TN:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_TN;  break;
             case GGML_TYPE_IQ4_NL:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL;  break;
             case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
@@ -4497,8 +4498,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_IQ5_K:   return "IQ5_K - 5.5 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ6_K:   return "IQ6_K - 6.6 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ1_BN:  return "IQ1_BN - 1.625 bpw Bitnet";
+        case LLAMA_FTYPE_MOSTLY_IQ1_TN:  return "IQ1_TN - 1.6875 bpw TriLM";
         case LLAMA_FTYPE_MOSTLY_IQ2_BN:  return "IQ2_BN - 2.00 bpw Bitnet";
-        case LLAMA_FTYPE_MOSTLY_IQ2_TN:  return "IQT_BN - 2.06 bpw TriLM";
+        case LLAMA_FTYPE_MOSTLY_IQ2_TN:  return "IQ2_TN - 2.06 bpw TriLM";
         case LLAMA_FTYPE_MOSTLY_IQ3_S:   return "IQ3_S - 3.4375 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ3_M:   return "IQ3_S mix - 3.66 bpw";
         case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
@@ -15644,7 +15646,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) {
             new_type = GGML_TYPE_IQ4_NL;
         }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_TN) {
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_TN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_TN) {
             new_type = GGML_TYPE_Q4_K;
         }
         else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
@@ -15856,7 +15858,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
         new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
         new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K ||
         new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN ||
-        new_type == GGML_TYPE_IQ6_K) {
+        new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];
         if (nx % QK_K != 0) {
@@ -15881,6 +15883,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
             case GGML_TYPE_IQ3_S:
             case GGML_TYPE_IQ1_S:
             case GGML_TYPE_IQ1_M:
+            case GGML_TYPE_IQ1_TN:
             case GGML_TYPE_IQ2_TN:
             case GGML_TYPE_Q2_K:
             case GGML_TYPE_Q3_K:
@@ -15991,6 +15994,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
         case LLAMA_FTYPE_MOSTLY_IQ1_M:   default_type = GGML_TYPE_IQ1_M;   break;
         case LLAMA_FTYPE_MOSTLY_IQ1_BN:  default_type = GGML_TYPE_IQ1_BN;  break;
         case LLAMA_FTYPE_MOSTLY_IQ2_BN:  default_type = GGML_TYPE_IQ2_BN;  break;
+        case LLAMA_FTYPE_MOSTLY_IQ1_TN:  default_type = GGML_TYPE_IQ1_TN;  break;
         case LLAMA_FTYPE_MOSTLY_IQ2_TN:  default_type = GGML_TYPE_IQ2_TN;  break;
         case LLAMA_FTYPE_MOSTLY_IQ4_NL:  default_type = GGML_TYPE_IQ4_NL;  break;
         case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;