author    Kawrakow <iwankawrakow@gmail.com>    2025-02-06 14:08:52 +0200
committer    GitHub <noreply@github.com>    2025-02-06 14:08:52 +0200
commit    7f61b3068e18728e5e7e2b95546ff03dd2fd41ac (patch)
tree    f175a942a6ebd2d2d8b08c46fa71d9f6fbad50e7 /ggml/src/iqk/iqk_mul_mat.cpp
parent    a6f9f2ec9af92b5a13f035db054aac2fd2efaee7 (diff)
IQ1_M_R4: better 1.75 bpw quants (#187)
* iq1_m_r4: basics (quantize/dequantize)
* iq1_m_r4: Zen4 gemm
* iq1_m_r4: neon gemm
* iq1_m_r4: switch to q8_0_x4 also on AVX2/Zen4
  With the deltas being per group of 8, we cannot make use of the q8 sums
  stored in q8_1, so we get a tiny gain by using q8_0_x4.
* iq1_m_r4: rename mul_mat_iq1_m_r4_q8_1 to mul_mat_iq1_m_r4_q8_0
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
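For context, here is a minimal scalar sketch of how one group of 8 IQ1_M_R4 weights appears to be reconstructed, inferred from the AVX2/Zen4 and NEON kernels in this diff, and of why the per-group delta rules out the q8_1 block sums. It is an illustration only: block_iq1_m_r4_sketch, dequant_iq1_m_r4_group, and the grid_us table layout (2048 entries, 8 packed byte values in {0, 1, 2}) are assumptions mirroring the real block_iq1_m_r4 and iq1s_grid_us, not code from the patch.

#include <cstdint>

// Hypothetical mirror of block_iq1_m_r4: 32 weights for each of 4 interleaved rows.
// Field sizes are inferred from the indexing in the kernels below.
struct block_iq1_m_r4_sketch {
    uint8_t qs[16];      // low 8 bits of the grid indices
    uint8_t qh[8];       // per nibble: 3 high index bits + delta-sign bit (0x08)
    uint8_t scales[4];   // 4-bit sub-block scales, read as one uint32
};

// Hypothetical scalar decode of one group of 8 weights for one of the 4 interleaved
// rows: w = d * scale * ((g - 1) + delta), with delta = +/-1/8 chosen per group of 8.
// grid_us stands in for iq1s_grid_us (assumed: 2048 entries, 8 bytes in {0,1,2} each).
inline void dequant_iq1_m_r4_group(const uint64_t * grid_us, float d, int scale,
                                   uint8_t qs, uint8_t qh_nibble, float * w) {
    const uint64_t g8    = grid_us[qs | ((qh_nibble & 0x07) << 8)];
    const float    delta = (qh_nibble & 0x08) ? -0.125f : 0.125f;  // one delta per 8 weights
    for (int j = 0; j < 8; ++j) {
        const int g = int((g8 >> 8*j) & 0xff);                     // 0, 1 or 2
        w[j] = d * scale * (float(g - 1) + delta);
    }
}

Because delta flips sign per group of 8 rather than per 32-weight block, the per-block activation sums stored in Q8_1_X4 cannot absorb the delta term; the kernels therefore take Q8_0_X4 activations and fold the delta into the packed weight bytes (the 8*g - 7 / 8*g - 9 values built in the code below), applying the global 0.125 factor only at the end.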
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--    ggml/src/iqk/iqk_mul_mat.cpp    197
1 file changed, 197 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index ea8e8274..57024602 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -260,6 +260,7 @@ struct MulMat {
case GGML_TYPE_IQ2_S_R4:
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ1_S_R4:
+ case GGML_TYPE_IQ1_M_R4:
case GGML_TYPE_IQ3_S_R4: return 4;
case GGML_TYPE_IQ4_NL_R4:
case GGML_TYPE_Q5_0_R4:
@@ -295,6 +296,7 @@ struct MulMat {
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4:
case GGML_TYPE_IQ1_S_R4:
+ case GGML_TYPE_IQ1_M_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
@@ -3609,6 +3611,102 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
}
}
+template <int nrc_y>
+static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_0_x4> q8(info);
+ int nb = n / 32;
+ GGML_ASSERT(nb%4 == 0);
+ auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000);
+ auto step = _mm256_set1_epi8(2);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ __m256i qx[4];
+ __m256 acc[nrc_y] = {};
+ auto ms = _mm_set1_epi8(0x08);
+ float d8[4*nrc_y];
+ union { __m256i vec; uint16_t val[16]; } helper;
+ for (int ix= 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)));
+ auto x = (const block_iq1_m_r4 *)(dptr + 4);
+ for (int ib = 0; ib < nb/4; ++ib) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ _mm_storeu_ps(d8 + 4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].d)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto qh = (const uint32_t *)x[4*ib+k].qh;
+ auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]);
+ auto scales4 = _mm_set1_epi32(((const uint32_t *)x[4*ib+k].scales)[0]);
+ scales4 = _mm_and_si128(_mm_srlv_epi32(scales4, _mm_set_epi32(4, 0, 4, 0)), _mm_set1_epi8(0xf));
+ scales4 = _mm_cvtepu8_epi16(scales4);
+ auto scales = MM256_SET_M128I(_mm_unpackhi_epi16(scales4, scales4), _mm_unpacklo_epi16(scales4, scales4));
+
+ auto signs128 = _mm_or_si128(_mm_cmpeq_epi8(_mm_and_si128(idxh, ms), ms), _mm_set1_epi8(1));
+ signs128 = _mm_add_epi8(_mm_set1_epi8(-8), signs128);
+ auto signs = MM256_SET_M128I(signs128, signs128);
+ auto idxl = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)x[4*ib+k].qs));
+ idxh = _mm_and_si128(idxh, _mm_set1_epi8(0x07));
+ helper.vec = _mm256_or_si256(idxl, _mm256_slli_epi16(_mm256_cvtepu8_epi16(idxh), 8));
+ qx[0] = _mm256_set_epi64x(iq1s_grid_us[helper.val[ 9]], iq1s_grid_us[helper.val[ 8]],
+ iq1s_grid_us[helper.val[ 1]], iq1s_grid_us[helper.val[ 0]]);
+ qx[1] = _mm256_set_epi64x(iq1s_grid_us[helper.val[13]], iq1s_grid_us[helper.val[12]],
+ iq1s_grid_us[helper.val[ 5]], iq1s_grid_us[helper.val[ 4]]);
+ qx[2] = _mm256_set_epi64x(iq1s_grid_us[helper.val[11]], iq1s_grid_us[helper.val[10]],
+ iq1s_grid_us[helper.val[ 3]], iq1s_grid_us[helper.val[ 2]]);
+ qx[3] = _mm256_set_epi64x(iq1s_grid_us[helper.val[15]], iq1s_grid_us[helper.val[14]],
+ iq1s_grid_us[helper.val[ 7]], iq1s_grid_us[helper.val[ 6]]);
+ qx[0] = _mm256_add_epi8(_mm256_slli_epi16(qx[0], 3), _mm256_shuffle_epi8(signs, shuffle0));
+ auto shuffle = _mm256_add_epi8(shuffle0, step);
+ qx[2] = _mm256_add_epi8(_mm256_slli_epi16(qx[2], 3), _mm256_shuffle_epi8(signs, shuffle));
+ shuffle = _mm256_add_epi8(shuffle, step);
+ qx[1] = _mm256_add_epi8(_mm256_slli_epi16(qx[1], 3), _mm256_shuffle_epi8(signs, shuffle));
+ shuffle = _mm256_add_epi8(shuffle, step);
+ qx[3] = _mm256_add_epi8(_mm256_slli_epi16(qx[3], 3), _mm256_shuffle_epi8(signs, shuffle));
+ auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ib].qs + k);
+ auto y1 = _mm256_shuffle_epi32(y, 0x44);
+ auto y2 = _mm256_shuffle_epi32(y, 0xee);
+#ifdef HAVE_FANCY_SIMD
+ // 0,0, 1,1, 0,0, 1,1 as int32_t
+ auto sumi1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
+ s0, _mm256_sign_epi8(y1, qx[0])), s1, _mm256_sign_epi8(y2, qx[1]));
+ // 2,2, 3,3, 2,2, 3,3 as int32_t
+ auto sumi2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(),
+ s2, _mm256_sign_epi8(y1, qx[2])), s3, _mm256_sign_epi8(y2, qx[3]));
+ auto sumi = _mm256_packs_epi32(sumi1, sumi2);
+#else
+ // 4 x row 0, 4 x row 1, 4 x row 0, 4 x row 1
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(s0, _mm256_sign_epi8(y1, qx[0])),
+ _mm256_maddubs_epi16(s1, _mm256_sign_epi8(y2, qx[1])));
+ // 4 x row 2, 4 x row 3, 4 x row 2, 4 x row 3
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(s2, _mm256_sign_epi8(y1, qx[2])),
+ _mm256_maddubs_epi16(s3, _mm256_sign_epi8(y2, qx[3])));
+ // 0,0, 1,1, 0,0, 1,1 as int32_t
+ sumi1 = _mm256_madd_epi16(m1, sumi1);
+ // 2,2, 3,3, 2,2, 3,3 as int32_t
+ sumi2 = _mm256_madd_epi16(m1, sumi2);
+ // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t
+ auto sumi = _mm256_packs_epi32(sumi1, sumi2);
+#endif
+ sumi = _mm256_madd_epi16(scales, sumi);
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, _mm_mul_ps(d1, sumf));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -9081,6 +9179,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
#endif
expected_typeB = GGML_TYPE_Q8_1_X4;
break;
+ case GGML_TYPE_IQ1_M_R4:
+ assert (ne00 % QK4_NL == 0);
+ mm.funcs[0] = mul_mat_iq1_m_r4_q8_0<1>;
+ mm.funcs[1] = mul_mat_iq1_m_r4_q8_0<2>;
+ mm.funcs[2] = mul_mat_iq1_m_r4_q8_0<3>;
+ mm.funcs[3] = mul_mat_iq1_m_r4_q8_0<4>;
+ mm.funcs[4] = mul_mat_iq1_m_r4_q8_0<5>;
+ mm.funcs[5] = mul_mat_iq1_m_r4_q8_0<6>;
+ mm.funcs[6] = mul_mat_iq1_m_r4_q8_0<7>;
+ mm.funcs[7] = mul_mat_iq1_m_r4_q8_0<8>;
+#ifdef HAVE_FANCY_SIMD
+ mm.func16 = mul_mat_iq1_m_r4_q8_0<16>;
+#endif
+ expected_typeB = GGML_TYPE_Q8_0_X4;
+ break;
default:
return false;
@@ -12093,6 +12206,85 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
}
template <int nrc_y>
+static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_0_x4> q8(info);
+ int nb = n / 32;
+ GGML_ASSERT(nb%4 == 0);
+ int8x16_t qx[8];
+    float32x4_t acc[nrc_y] = {};
+ auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303};
+ auto step = vdupq_n_u8(4);
+ auto ms = vdupq_n_u8(0x08);
+ auto mask = vdupq_n_s8(0x18);
+ float d8[4*nrc_y];
+ for (int ix= 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
+ auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr)));
+ auto x = (const block_iq1_m_r4 *)(dptr + 4);
+ for (int ib = 0; ib < nb/4; ++ib) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = vld1_f16((const float16_t *)q8.y[iy][ib].d);
+ vst1q_f32(d8+4*iy, vcvt_f32_f16(scales));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]);
+ scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf));
+ auto scales16 = vmovl_u8(scales4);
+ auto scales1 = vmovl_u16(vget_low_u16(scales16));
+ auto scales2 = vmovl_u16(vget_high_u16(scales16));
+ auto qh = (const uint32_t *)x[4*ib+k].qh;
+ auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4};
+ auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1)));
+ signs = vaddq_s8(signs, vdupq_n_s8(-8));
+ qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
+ qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
+ qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
+ qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
+ iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
+ auto shuffle = shuffle0;
+ for (int j = 0; j < 4; ++j) {
+ auto s = vqtbl1q_s8(signs, shuffle);
+ qx[2*j+1] = vaddq_s8(s, vandq_s8(vshrq_n_s8(qx[2*j+0], 1), mask));
+ qx[2*j+0] = vaddq_s8(s, vandq_s8(vshlq_n_s8(qx[2*j+0], 3), mask));
+ shuffle = vaddq_u8(shuffle, step);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
+ auto sumi1 = vdupq_n_s32(0);
+ auto sumi2 = vdupq_n_s32(0);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
+ sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
+ sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
+ auto sumi = vmlaq_s32(vmlaq_s32(vdupq_n_s32(0), sumi1, scales1), sumi2, scales2);
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(d1, acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -13717,6 +13909,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.func16 = mul_mat_iq1_s_r4_q8_1<16>;
expected_Btype = GGML_TYPE_Q8_1_X4;
break;
+ case GGML_TYPE_IQ1_M_R4:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0);
+ m.func16 = mul_mat_iq1_m_r4_q8_0<16>;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
+ break;
case GGML_TYPE_IQ3_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;