diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-02-06 14:08:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-06 14:08:52 +0200 |
commit | 7f61b3068e18728e5e7e2b95546ff03dd2fd41ac (patch) | |
tree | f175a942a6ebd2d2d8b08c46fa71d9f6fbad50e7 /ggml/src/iqk/iqk_mul_mat.cpp | |
parent | a6f9f2ec9af92b5a13f035db054aac2fd2efaee7 (diff) |
IQ1_M_R4: better 1.75 bpw quants (#187)
* iq1_m_r4: basics (quantize/dequantize)
* iq1_m_r4: Zen4 gemm
* iq1_m_r4: neon gemm
* iq1_m_r4: switch to q8_0_x4 also on AVX2/Zen4
With the deltas being per group of 8, we cannot make use
of the q8 sums stored in q8_1, so we get a tiny gain by
using q8_0_x4.
* iq1_m_r4: rename mul_mat_iq1_m_r4_q8_1 to mul_mat_iq1_m_r4_q8_0
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ea8e8274..57024602 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -260,6 +260,7 @@ struct MulMat { case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_IQ3_S_R4: return 4; case GGML_TYPE_IQ4_NL_R4: case GGML_TYPE_Q5_0_R4: @@ -295,6 +296,7 @@ struct MulMat { case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_IQ2_BN_R4: return 4; case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_Q4_0_R4: @@ -3609,6 +3611,102 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +template <int nrc_y> +static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_0_x4> q8(info); + int nb = n / 32; + GGML_ASSERT(nb%4 == 0); + auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000); + auto step = _mm256_set1_epi8(2); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + __m256i qx[4]; + __m256 acc[nrc_y] = {}; + auto ms = _mm_set1_epi8(0x08); + float d8[4*nrc_y]; + union { __m256i vec; uint16_t val[16]; } helper; + for (int ix= 0; ix < nrc_x; ix += 4) { + auto dptr = (const ggml_half *)((const char *)vx + ix*bx); + auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr))); + auto x = (const block_iq1_m_r4 *)(dptr + 4); + for (int ib = 0; ib < nb/4; ++ib) { + for (int iy = 0; iy < nrc_y; ++iy) { + _mm_storeu_ps(d8 + 4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].d))); + } + for (int k = 0; k < 4; ++k) { + auto qh = (const uint32_t *)x[4*ib+k].qh; + auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]); + auto scales4 = _mm_set1_epi32(((const uint32_t *)x[4*ib+k].scales)[0]); + scales4 = _mm_and_si128(_mm_srlv_epi32(scales4, _mm_set_epi32(4, 0, 4, 0)), _mm_set1_epi8(0xf)); + scales4 = _mm_cvtepu8_epi16(scales4); + auto scales = MM256_SET_M128I(_mm_unpackhi_epi16(scales4, scales4), _mm_unpacklo_epi16(scales4, scales4)); + + auto signs128 = _mm_or_si128(_mm_cmpeq_epi8(_mm_and_si128(idxh, ms), ms), _mm_set1_epi8(1)); + signs128 = _mm_add_epi8(_mm_set1_epi8(-8), signs128); + auto signs = MM256_SET_M128I(signs128, signs128); + auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs)); + idxh = _mm_and_si128(idxh, _mm_set1_epi8(0x07)); + helper.vec = _mm256_or_si256(idxl, _mm256_slli_epi16(_mm256_cvtepu8_epi16(idxh), 8)); + qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]], + iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]); + qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]], + iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]); + qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]], + iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]); + qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]], + iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]); + qx[0] = _mm256_add_epi8(_mm256_slli_epi16(qx[0], 3), _mm256_shuffle_epi8(signs, shuffle0)); + auto shuffle = _mm256_add_epi8(shuffle0, step); + qx[2] = _mm256_add_epi8(_mm256_slli_epi16(qx[2], 3), _mm256_shuffle_epi8(signs, shuffle)); + shuffle = _mm256_add_epi8(shuffle, step); + qx[1] = _mm256_add_epi8(_mm256_slli_epi16(qx[1], 3), _mm256_shuffle_epi8(signs, shuffle)); + shuffle = _mm256_add_epi8(shuffle, step); + qx[3] = _mm256_add_epi8(_mm256_slli_epi16(qx[3], 3), _mm256_shuffle_epi8(signs, shuffle)); + auto s0 = _mm256_sign_epi8(qx[0], qx[0]); + auto s1 = _mm256_sign_epi8(qx[1], qx[1]); + auto s2 = _mm256_sign_epi8(qx[2], qx[2]); + auto s3 = _mm256_sign_epi8(qx[3], qx[3]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k); + auto y1 = _mm256_shuffle_epi32(y, 0x44); + auto y2 = _mm256_shuffle_epi32(y, 0xee); +#ifdef HAVE_FANCY_SIMD + // 0,0, 1,1, 0,0, 1,1 as int32_t + auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), + s0, _mm256_sign_epi8(y1, qx[0])), s1, _mm256_sign_epi8(y2, qx[1])); + // 2,2, 3,3, 2,2, 3,3 as int32_t + auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), + s2, _mm256_sign_epi8(y1, qx[2])), s3, _mm256_sign_epi8(y2, qx[3])); + auto sumi = _mm256_packs_epi32(sumi1, sumi2); +#else + // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1 + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(s0, _mm256_sign_epi8(y1, qx[0])), + _mm256_maddubs_epi16(s1, _mm256_sign_epi8(y2, qx[1]))); + // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3 + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(s2, _mm256_sign_epi8(y1, qx[2])), + _mm256_maddubs_epi16(s3, _mm256_sign_epi8(y2, qx[3]))); + // 0,0, 1,1, 0,0, 1,1 as int32_t + sumi1 = _mm256_madd_epi16(m1, sumi1); + // 2,2, 3,3, 2,2, 3,3 as int32_t + sumi2 = _mm256_madd_epi16(m1, sumi2); + // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t + auto sumi = _mm256_packs_epi32(sumi1, sumi2); +#endif + sumi = _mm256_madd_epi16(scales, sumi); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(d1, sumf)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + #ifdef HAVE_FANCY_SIMD template <int nrc_y> static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -9081,6 +9179,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #endif expected_typeB = GGML_TYPE_Q8_1_X4; break; + case GGML_TYPE_IQ1_M_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_iq1_m_r4_q8_0<1>; + mm.funcs[1] = mul_mat_iq1_m_r4_q8_0<2>; + mm.funcs[2] = mul_mat_iq1_m_r4_q8_0<3>; + mm.funcs[3] = mul_mat_iq1_m_r4_q8_0<4>; + mm.funcs[4] = mul_mat_iq1_m_r4_q8_0<5>; + mm.funcs[5] = mul_mat_iq1_m_r4_q8_0<6>; + mm.funcs[6] = mul_mat_iq1_m_r4_q8_0<7>; + mm.funcs[7] = mul_mat_iq1_m_r4_q8_0<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_iq1_m_r4_q8_0<16>; +#endif + expected_typeB = GGML_TYPE_Q8_0_X4; + break; default: return false; @@ -12093,6 +12206,85 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } template <int nrc_y> +static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<nrc_y, block_q8_0_x4> q8(info); + int nb = n / 32; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[8]; + int32x4_t acc[nrc_y] = {}; + auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303}; + auto step = vdupq_n_u8(4); + auto ms = vdupq_n_u8(0x08); + auto mask = vdupq_n_s8(0x18); + float d8[4*nrc_y]; + for (int ix= 0; ix < nrc_x; ix += 4) { + auto dptr = (const ggml_half *)((const char *)vx + ix*bx); + auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr))); + auto x = (const block_iq1_m_r4 *)(dptr + 4); + for (int ib = 0; ib < nb/4; ++ib) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = vld1_f16((const float16_t *)q8.y[iy][ib].d); + vst1q_f32(d8+4*iy, vcvt_f32_f16(scales)); + } + for (int k = 0; k < 4; ++k) { + auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]); + scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf)); + auto scales16 = vmovl_u8(scales4); + auto scales1 = vmovl_u16(vget_low_u16(scales16)); + auto scales2 = vmovl_u16(vget_high_u16(scales16)); + auto qh = (const uint32_t *)x[4*ib+k].qh; + auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4}; + auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1))); + signs = vaddq_s8(signs, vdupq_n_s8(-8)); + qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); + qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]}); + qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]}); + qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)], + iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]}); + auto shuffle = shuffle0; + for (int j = 0; j < 4; ++j) { + auto s = vqtbl1q_s8(signs, shuffle); + qx[2*j+1] = vaddq_s8(s, vandq_s8(vshrq_n_s8(qx[2*j+0], 1), mask)); + qx[2*j+0] = vaddq_s8(s, vandq_s8(vshlq_n_s8(qx[2*j+0], 3), mask)); + shuffle = vaddq_u8(shuffle, step); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k); + auto sumi1 = vdupq_n_s32(0); + auto sumi2 = vdupq_n_s32(0); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[0]), y.val[0], 0); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[1]), y.val[0], 1); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[2]), y.val[0], 2); + sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[3]), y.val[0], 3); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[4]), y.val[1], 0); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); + sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); + auto sumi = vmlaq_s32(vmlaq_s32(vdupq_n_s32(0), sumi1, scales1), sumi2, scales2); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), vcvtq_f32_s32(sumi)); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(d1, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template <int nrc_y> static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8<nrc_y, block_q8_K> q8(info); @@ -13717,6 +13909,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { m.func16 = mul_mat_iq1_s_r4_q8_1<16>; expected_Btype = GGML_TYPE_Q8_1_X4; break; + case GGML_TYPE_IQ1_M_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0); + m.func16 = mul_mat_iq1_m_r4_q8_0<16>; + expected_Btype = GGML_TYPE_Q8_0_X4; + break; case GGML_TYPE_IQ3_XXS_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k); m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>; |