summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-07-18 18:55:43 +0200
committerGitHub <noreply@github.com>2025-07-18 18:55:43 +0200
commitcc82006f51e0254279bfd46c3c7d97cb12d7dc18 (patch)
tree3691d7e09ff5dc7a5d7957e84fb83fa2fa702997
parentb94f3af56f6fde4845c968115edaa0ac36e36bb7 (diff)
GEMM for iq1_m (#630)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/iqk/iqk_gemm_1bit.cpp67
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp4
2 files changed, 70 insertions, 1 deletions
diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp
index 2c0a1bda..ece76f3c 100644
--- a/ggml/src/iqk/iqk_gemm_1bit.cpp
+++ b/ggml/src/iqk/iqk_gemm_1bit.cpp
@@ -866,6 +866,68 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
}
template <int nrc_y>
+// GEMM kernel: multiplies nrc_x rows of IQ1_M-quantized weights (vx, row
+// stride bx bytes) by nrc_y rows of Q8_K-quantized activations, writing one
+// float result per (ix, iy) pair via info.store(). n is the row length in
+// elements and must be a multiple of QK_K.
+void mul_mat_iq1_m_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    __m256i qx[8];                // dequantized (8x-scaled) weights for one QK_K super-block
+    __m256 acc[nrc_y] = {};       // one float accumulator vector per activation row
+    // Replicates each of the four 16-bit entries of iq1m.scales into its own
+    // 64-bit lane, so the packed 3-bit sub-block scales can be extracted
+    // lane-wise below (L42-L44 style bit-twiddling).
+    auto scale_shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0504050405040504, 0x0302030203020302, 0x0100010001000100);
+    // One delta bit per 64-bit lane: bits 0x08/0x80 of the low and high qh
+    // bytes select between the two shift values (9 vs 7) computed below --
+    // see the commented-out reference implementation further down.
+    auto delta_mask = _mm256_set_epi64x(0x8000, 0x0800, 0x0080, 0x0008);
+    iq1m_scale_t scale;           // fp16 super-block scale, reassembled from the 4-bit pieces in sc[0..3]
+    union { __m256i vec; int16_t val[16]; } helper;   // 16 sub-block scale factors (2*s + 1)
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx);
+        for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+            const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales; // 4 x uint16_t, each containing 4 scales
+            // The top nibble of each sc word holds 4 bits of the shared fp16 scale.
+            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+            float d = GGML_FP16_TO_FP32(scale.f16);
+            auto qs = iq1m[ibl].qs;
+            auto qh = iq1m[ibl].qh;
+            // Extract the 16 packed 3-bit sub-block scales s: mask each field,
+            // multiply so the field lands as 2*s in the high byte of its
+            // 16-bit lane, shift down, and add 1 -> helper.val[k] = 2*s + 1.
+            auto aux = _mm_loadl_epi64((const __m128i *)iq1m[ibl].scales);
+            auto sc16 = _mm256_shuffle_epi8(MM256_SET_M128I(aux, aux), scale_shuffle);
+            sc16 = _mm256_and_si256(sc16, _mm256_set1_epi64x(0x0e0001c000380007));
+            sc16 = _mm256_mullo_epi16(sc16, _mm256_set1_epi64x(0x0001000800400200));
+            helper.vec = _mm256_add_epi8(_mm256_srli_epi16(sc16, 8), _mm256_set1_epi16(1));
+            for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+                // Each 8-bit qs index is extended with 3 bits from qh to form
+                // an 11-bit index into the iq1s_grid_us codebook (8 values per lookup).
+                qx[2*ib64+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid_us[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
+                                                 iq1s_grid_us[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid_us[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
+                qx[2*ib64+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid_us[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
+                                                 iq1s_grid_us[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid_us[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
+                //auto delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
+                //auto delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[3] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[2] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[2] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
+                // Branchless form of the commented-out code: cmpeq yields all-ones
+                // (-1) in lanes whose delta bit is set, so or(-1,1) = -1 and
+                // 8 - (-1) = 9 there, while clear lanes give 8 - 1 = 7.
+                auto qh16 = (const uint16_t *)qh;
+                auto delta1 = _mm256_cmpeq_epi64(_mm256_and_si256(_mm256_set1_epi64x(qh16[0]), delta_mask), delta_mask);
+                auto delta2 = _mm256_cmpeq_epi64(_mm256_and_si256(_mm256_set1_epi64x(qh16[1]), delta_mask), delta_mask);
+                delta1 = _mm256_sub_epi8(_mm256_set1_epi8(8), _mm256_or_si256(delta1, _mm256_set1_epi8(1)));
+                delta2 = _mm256_sub_epi8(_mm256_set1_epi8(8), _mm256_or_si256(delta2, _mm256_set1_epi8(1)));
+                // Work with 8x-scaled values to stay in integers: 8*g - (8 +/- 1)
+                // = 8*(g-1) -/+ 1; the final 0.125f in the store undoes the factor of 8.
+                qx[2*ib64+0] = _mm256_sub_epi8(_mm256_slli_epi16(qx[2*ib64+0], 3), delta1);
+                qx[2*ib64+1] = _mm256_sub_epi8(_mm256_slli_epi16(qx[2*ib64+1], 3), delta2);
+                qs += 8;
+                qh += 4;
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sumi = _mm256_setzero_si256();
+                for (int j = 0; j < 8; ++j) {
+                    // maddubs requires an unsigned first operand, so move the sign
+                    // of qx onto the q8 quants with _mm256_sign_epi8 before multiplying.
+                    auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx[j], qx[j]), _mm256_sign_epi8(q8.load_quants(iy, ibl, j), qx[j]));
+                    // Apply the two (2*s+1) sub-block scales covering this 32-value chunk.
+                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(p, MM256_SET_M128I(_mm_set1_epi16(helper.val[2*j+1]), _mm_set1_epi16(helper.val[2*j+0]))));
+                }
+                // Fold in the fp16 super-block scale and the Q8_K activation row scale.
+                acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            // 0.125f = 1/8 compensates the <<3 applied to the grid values above.
+            info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+
+template <int nrc_y>
void mul_mat_iq1_s_q8_2_x4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
Q8<nrc_y, block_q8_2_x4> q8(info);
@@ -1844,6 +1906,11 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
func16 = mul_mat_iq1_s_r4_q8_1<16>;
#endif
break;
+ case GGML_TYPE_IQ1_M:
+ if (ne00%QK_K != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_q8_K, funcs);
+ expected_typeB = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_IQ1_M_R4:
if (ne00%128 != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index fb951e70..31de5c42 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -243,7 +243,7 @@ struct MulMat {
case GGML_TYPE_IQ4_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
- case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ1_M : return nrc_y >= 999932 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_Q2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
@@ -867,6 +867,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ4_NL_R4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
case GGML_TYPE_IQ1_BN:
@@ -958,6 +959,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);