diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-12-02 17:01:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-02 17:01:48 +0100 |
commit | 239a344f9935cb614c6208000edb0815214286a1 (patch) | |
tree | aaa1933857cd9575fecfec14515e43e24305ce1f | |
parent | 6d0462d4a39085a9f9da04e0a5fc7cc9d4578818 (diff) |
Q4_0_R4 (#119)
* Adding iq4_0_r4 - q4_0 repacked
We get PP-512(LLaMA-3.1-8B) = 278 t/s on a Ryzen-7950X CPU,
so ~5-6% faster than iq4_nl_x4.
* q4_0_r4: NEON
Here we get 115.8 t/s, so also ~5% better than iq4_nl_x4.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | examples/quantize/quantize.cpp | 1 | ||||
-rw-r--r-- | ggml/include/ggml.h | 2 | ||||
-rw-r--r-- | ggml/src/ggml-quants.c | 1 | ||||
-rw-r--r-- | ggml/src/ggml.c | 26 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 170 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 99 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 6 | ||||
-rw-r--r-- | include/llama.h | 1 | ||||
-rw-r--r-- | src/llama.cpp | 13 |
9 files changed, 304 insertions, 15 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 333fae36..9e0dc3cf 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -41,6 +41,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_NL_X4",LLAMA_FTYPE_MOSTLY_IQ4_NL_X4," 4.50 bpw non-linear quantization", }, + { "Q4_0_R4", LLAMA_FTYPE_MOSTLY_Q4_0_R4, " 4.50 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index dabb2264..1a46881a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -407,6 +407,7 @@ extern "C" { GGML_TYPE_IQ2_KS = 145, GGML_TYPE_IQ4_KSS = 146, + GGML_TYPE_Q4_0_R4 = 202, GGML_TYPE_IQ4_NL_X4 = 220, GGML_TYPE_COUNT, }; @@ -467,6 +468,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors // + GGML_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_NL_X4 = 219, // except 1d tensors }; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 376c97f8..dd43e1c1 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15197,6 +15197,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ4_KS: break; case GGML_TYPE_IQ4_KSS: break; case GGML_TYPE_IQ4_NL_X4: break; + case GGML_TYPE_Q4_0_R4: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c975212e..dfe1017b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1262,6 +1262,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q4_0_R4] = { + .type_name = "q4_0_r4", + .blck_size = QK4_NL, + .type_size = sizeof(block_iq4_nl), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_0_r4, + .from_float = quantize_row_q4_0_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r4_ref, + .vec_dot = vec_dot_q4_0_r4_q8_0, +#if GGML_USE_IQK_MULMAT && defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1, +#else + .vec_dot_type = GGML_TYPE_Q8_0, +#endif + .nrows = 1, + .row_meta_size = 0, + }, }; // For internal test use @@ -3921,6 +3938,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_NL_X4: wtype = GGML_TYPE_IQ4_NL_X4;break; + case GGML_FTYPE_MOSTLY_Q4_0_R4: wtype = GGML_TYPE_Q4_0_R4; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break; case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break; @@ -10445,6 +10463,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_X4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: @@ -10888,6 +10907,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_X4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: @@ -11028,6 +11048,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_X4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: @@ -14214,6 +14235,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_X4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: @@ -14594,6 +14616,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_X4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: @@ -14868,6 +14891,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_X4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: @@ -15469,6 +15493,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_BN: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL_X4: + case GGML_TYPE_Q4_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KSS: @@ -22296,6 +22321,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL_X4: result = quantize_iq4_nl_x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_R4: result = quantize_q4_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index cfda9e18..13ca6724 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2185,6 +2185,104 @@ static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const Data } #endif +#ifdef HAVE_FANCY_SIMD +template <int nrc_y> +static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8<nrc_y, block_q8_1_x4> q8(info); + auto m4 = _mm512_set1_epi8(0xf); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + __m512 acc[2*nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_x4 * iq4l = (const block_iq4_nl_x4 *)((const char *)vx + (ix+0)*bx); + const block_iq4_nl_x4 * iq4h = (const block_iq4_nl_x4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-4.f)); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)), + _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)), + _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1); + qx[0] = _mm512_and_si512(bits1, m4); + qx[1] = _mm512_and_si512(bits2, m4); + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } +} +#else +template <int nrc_y> +static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_1_x4> q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m1 = _mm256_set1_epi16(1); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + __m256 acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-4.f)); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1); + auto q1 = _mm256_and_si256(bits1, m4); + auto q2 = _mm256_and_si256(bits2, m4); + auto q3 = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4); + auto q4 = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} +#endif + template <typename Bits> inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { @@ -4154,6 +4252,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_iq4_nl_x4_q8_1<8>; expected_typeB = GGML_TYPE_Q8_1; break; + case GGML_TYPE_Q4_0_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_q4_0_r4_q8_1<1>; + mm.funcs[1] = mul_mat_q4_0_r4_q8_1<2>; + mm.funcs[2] = mul_mat_q4_0_r4_q8_1<3>; + mm.funcs[3] = mul_mat_q4_0_r4_q8_1<4>; + mm.funcs[4] = mul_mat_q4_0_r4_q8_1<5>; + mm.funcs[5] = mul_mat_q4_0_r4_q8_1<6>; + mm.funcs[6] = mul_mat_q4_0_r4_q8_1<7>; + mm.funcs[7] = mul_mat_q4_0_r4_q8_1<8>; + expected_typeB = GGML_TYPE_Q8_1; + break; default: return false; @@ -6646,6 +6756,55 @@ void mul_mat_iq4_nl_x4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& } } +template <int nrc_y> +void mul_mat_q4_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_0_x4> q8(info); + auto m4 = vdupq_n_u8(0xf0); + auto m88 = vdupq_n_u8(0x88); + auto norm = vdupq_n_f32(1.f/16); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[8]; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); + auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]); + qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows + qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19 + qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7 + qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23 + qx[4] = vandq_u8(bits.val[0], m4); // 8..11 + qx[5] = vandq_u8(bits.val[1], m4); // 24..27 + qx[6] = vandq_u8(bits.val[2], m4); // 12..15 + qx[7] = vandq_u8(bits.val[3], m4); // 28..31 + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); + acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(norm, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> || std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> || @@ -6826,6 +6985,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { m.funcs[7] = mul_mat_iq4_nl_x4_q8_0<8>; expected_Btype = GGML_TYPE_Q8_0; break; + case GGML_TYPE_Q4_0_R4: + m.funcs[0] = mul_mat_q4_0_r4_q8_0<1>; + m.funcs[1] = mul_mat_q4_0_r4_q8_0<2>; + m.funcs[2] = mul_mat_q4_0_r4_q8_0<3>; + m.funcs[3] = mul_mat_q4_0_r4_q8_0<4>; + m.funcs[4] = mul_mat_q4_0_r4_q8_0<5>; + m.funcs[5] = mul_mat_q4_0_r4_q8_0<6>; + m.funcs[6] = mul_mat_q4_0_r4_q8_0<7>; + m.funcs[7] = mul_mat_q4_0_r4_q8_0<8>; + expected_Btype = GGML_TYPE_Q8_0; + break; default: return false; } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 88b5628c..b9e6ff68 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -3120,7 +3120,7 @@ void vec_dot_iq4_kss_q8_k(int n, float * s, size_t bs, const void * vx, size_t b } // -// ========================================= x4 +// ========================================= iq4_nl_x4 // void quantize_row_iq4_nl_x4_ref(const float * x, block_iq4_nl_x4 * y, int64_t k) { // we assume we are called with 4 rows @@ -3157,24 +3157,10 @@ size_t quantize_iq4_nl_x4(const float * src, void * dst, int64_t nrows, int64_t GGML_ASSERT(nrows%4 == 0); auto row_size_nl = ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row); std::vector<char> qtmp(4*row_size_nl); - //std::vector<float> check1(4*n_per_row), check2(4*n_per_row); char * qrow = (char *)dst; for (int row = 0; row < nrows; row += 4) { quantize_iq4_nl(src, qtmp.data(), 4, n_per_row, imatrix); repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_x4 *)qrow); - //dequantize_row_iq4_nl_x4((const block_iq4_nl_x4 *)qrow, check1.data(), 4*n_per_row); - //dequantize_row_iq4_nl((const block_iq4_nl *)qtmp.data(), check2.data(), 4*n_per_row); - //for (int k = 0; k < 4; ++k) { - // auto x1 = check1.data() + k*n_per_row; - // auto x2 = check2.data() + k*n_per_row; - // int nbad = 0; - // for (int j = 0; j < n_per_row; ++j) { - // if (std::abs(x1[j] - x2[j]) > 1e-8) { - // printf("Oops: %g vs %g\n", x1[j], x2[j]); - // if (++nbad > 20) GGML_ABORT("fatal error"); - // } - // } - //} src += 4*n_per_row; qrow += 4*row_size_nl; } @@ -3217,4 +3203,87 @@ void vec_dot_iq4_nl_x4_q8_0(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +// +// ========================================= q4_0_r4 +// +void quantize_row_q4_0_r4_ref(const float * x, block_iq4_nl_x4 * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q4_0_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q4_0_r4(const float * x, void * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q4_0_r4(x, y, 4, k/4, nullptr); +} + +static void repack_q4_0(int nrows, int n_per_row, const block_q4_0 * x, block_iq4_nl_x4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK4_NL == 0); + int nblock = n_per_row/QK4_NL; + const block_q4_0 * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[4*k+i+ 0] = (x4[k][ib].qs[i+0] & 0xf) | ((x4[k][ib].qs[i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row + y[ib].qs[4*k+i+16] = (x4[k][ib].qs[i+0] >> 4) | ((x4[k][ib].qs[i+ 8] & 0xf0)); // 16...19 + 24...27 from each row + y[ib].qs[4*k+i+32] = (x4[k][ib].qs[i+4] & 0xf) | ((x4[k][ib].qs[i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row + y[ib].qs[4*k+i+48] = (x4[k][ib].qs[i+4] >> 4) | ((x4[k][ib].qs[i+12] & 0xf0)); // 20...23 + 28...31 from each row + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q4_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + auto row_size_nl = ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row); + std::vector<char> qtmp(4*row_size_nl); + char * qrow = (char *)dst; + for (int row = 0; row < nrows; row += 4) { + quantize_q4_0(src, qtmp.data(), 4, n_per_row, imatrix); + repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_x4 *)qrow); + src += 4*n_per_row; + qrow += 4*row_size_nl; + } + return nrows*row_size_nl; +} + +void dequantize_row_q4_0_r4(const block_iq4_nl_x4 * x, float * y, int64_t k) { + // we assume we are called with 4 rows + int n_per_row = k/4; + int nb = n_per_row/QK4_NL; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nb; ++ib) { + for (int k = 0; k < 4; ++k) { + float scale = GGML_FP16_TO_FP32(x[ib].d[k]); + for (int i = 0; i < 4; ++i) { + yk[k][QK4_NL*ib+i+ 0] = scale * ((x[ib].qs[4*k+i+ 0] & 0xf) - 8); + yk[k][QK4_NL*ib+i+ 8] = scale * ((x[ib].qs[4*k+i+ 0] >> 4) - 8); + yk[k][QK4_NL*ib+i+16] = scale * ((x[ib].qs[4*k+i+16] & 0xf) - 8); + yk[k][QK4_NL*ib+i+24] = scale * ((x[ib].qs[4*k+i+16] >> 4) - 8); + yk[k][QK4_NL*ib+i+ 4] = scale * ((x[ib].qs[4*k+i+32] & 0xf) - 8); + yk[k][QK4_NL*ib+i+12] = scale * ((x[ib].qs[4*k+i+32] >> 4) - 8); + yk[k][QK4_NL*ib+i+20] = scale * ((x[ib].qs[4*k+i+48] & 0xf) - 8); + yk[k][QK4_NL*ib+i+28] = scale * ((x[ib].qs[4*k+i+48] >> 4) - 8); + } + } + } +} + +void vec_dot_q4_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q4_0_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 7942cc04..98c9c010 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -69,6 +69,12 @@ size_t quantize_iq4_nl_x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT void dequantize_row_iq4_nl_x4(const block_iq4_nl_x4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_nl_x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_q4_0_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_x4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q4_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q4_0_r4(const block_iq4_nl_x4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q4_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + #ifdef __cplusplus } #endif diff --git a/include/llama.h b/include/llama.h index f94dcb1a..fc034b3a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -180,6 +180,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors // + LLAMA_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_NL_X4 = 225, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file diff --git a/src/llama.cpp b/src/llama.cpp index 6eac67b6..9505f56f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3850,6 +3850,7 @@ struct llama_model_loader { case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_NL_X4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_X4;break; + case GGML_TYPE_Q4_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R4; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break; @@ -4555,6 +4556,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL_X4:return "IQ4_NL_X4 - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_Q4_0_R4: return "Q4_0_R4 - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: return "IQ4_KSS - 4.0 bpw"; @@ -15771,6 +15773,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_IQ4_NL_X4) { new_type = GGML_TYPE_IQ4_NL; } + else if (new_type == GGML_TYPE_Q4_0_R4) { + new_type = GGML_TYPE_Q4_0; + } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || @@ -15941,6 +15946,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix. new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_0_R4 && qs.has_imatrix && i_layer < n_layer/8) { + new_type = GGML_TYPE_IQ4_NL_X4; + } ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { if (qs.params->attn_output_type < GGML_TYPE_COUNT) new_type = qs.params->attn_output_type; @@ -16160,6 +16168,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL_X4:default_type = GGML_TYPE_IQ4_NL_X4;break; + case LLAMA_FTYPE_MOSTLY_Q4_0_R4: default_type = GGML_TYPE_Q4_0_R4; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: default_type = GGML_TYPE_IQ4_KSS; break; @@ -16521,6 +16530,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_NL; else chunk_size_multiplier = 4; } + if (new_type == GGML_TYPE_Q4_0_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; + else chunk_size_multiplier = 4; + } LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); fflush(stdout); |