-rw-r--r--  ggml/src/iqk/iqk_gemm_1bit.cpp  67
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp     4
2 files changed, 70 insertions, 1 deletion
diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp
index 2c0a1bda..ece76f3c 100644
--- a/ggml/src/iqk/iqk_gemm_1bit.cpp
+++ b/ggml/src/iqk/iqk_gemm_1bit.cpp
@@ -866,6 +866,68 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
 }
 
 template <int nrc_y>
+void mul_mat_iq1_m_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    __m256i qx[8];
+    __m256  acc[nrc_y] = {};
+    auto scale_shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0504050405040504, 0x0302030203020302, 0x0100010001000100);
+    auto delta_mask = _mm256_set_epi64x(0x8000, 0x0800, 0x0080, 0x0008);
+    iq1m_scale_t scale;
+    union { __m256i vec; int16_t val[16]; } helper;
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx);
+        for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+            const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales; // 4 x uint16_t, each containing 4 scales
+            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+            float d = GGML_FP16_TO_FP32(scale.f16);
+            auto qs = iq1m[ibl].qs;
+            auto qh = iq1m[ibl].qh;
+            auto aux = _mm_loadl_epi64((const __m128i *)iq1m[ibl].scales);
+            auto sc16 = _mm256_shuffle_epi8(MM256_SET_M128I(aux, aux), scale_shuffle);
+            sc16 = _mm256_and_si256(sc16, _mm256_set1_epi64x(0x0e0001c000380007));
+            sc16 = _mm256_mullo_epi16(sc16, _mm256_set1_epi64x(0x0001000800400200));
+            helper.vec = _mm256_add_epi8(_mm256_srli_epi16(sc16, 8), _mm256_set1_epi16(1));
+            for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+                qx[2*ib64+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid_us[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
+                                                 iq1s_grid_us[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid_us[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
+                qx[2*ib64+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid_us[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
+                                                 iq1s_grid_us[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid_us[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
+                //auto delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
+                //auto delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[3] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[2] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
+                //                                qh[2] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
+                auto qh16 = (const uint16_t *)qh;
+                auto delta1 = _mm256_cmpeq_epi64(_mm256_and_si256(_mm256_set1_epi64x(qh16[0]), delta_mask), delta_mask);
+                auto delta2 = _mm256_cmpeq_epi64(_mm256_and_si256(_mm256_set1_epi64x(qh16[1]), delta_mask), delta_mask);
+                delta1 = _mm256_sub_epi8(_mm256_set1_epi8(8), _mm256_or_si256(delta1, _mm256_set1_epi8(1)));
+                delta2 = _mm256_sub_epi8(_mm256_set1_epi8(8), _mm256_or_si256(delta2, _mm256_set1_epi8(1)));
+                qx[2*ib64+0] = _mm256_sub_epi8(_mm256_slli_epi16(qx[2*ib64+0], 3), delta1);
+                qx[2*ib64+1] = _mm256_sub_epi8(_mm256_slli_epi16(qx[2*ib64+1], 3), delta2);
+                qs += 8;
+                qh += 4;
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sumi = _mm256_setzero_si256();
+                for (int j = 0; j < 8; ++j) {
+                    auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx[j], qx[j]), _mm256_sign_epi8(q8.load_quants(iy, ibl, j), qx[j]));
+                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(p, MM256_SET_M128I(_mm_set1_epi16(helper.val[2*j+1]), _mm_set1_epi16(helper.val[2*j+0]))));
+                }
+                acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+
+template <int nrc_y>
 void mul_mat_iq1_s_q8_2_x4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     Q8<nrc_y, block_q8_2_x4> q8(info);
@@ -1844,6 +1906,11 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
             func16 = mul_mat_iq1_s_r4_q8_1<16>;
 #endif
             break;
+        case GGML_TYPE_IQ1_M:
+            if (ne00%QK_K != 0) return false;
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_q8_K, funcs);
+            expected_typeB = GGML_TYPE_Q8_K;
+            break;
         case GGML_TYPE_IQ1_M_R4:
             if (ne00%128 != 0) return false;
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index fb951e70..31de5c42 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -243,7 +243,7 @@ struct MulMat {
             case GGML_TYPE_IQ4_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
             case GGML_TYPE_IQ3_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
             case GGML_TYPE_IQ1_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
-            case GGML_TYPE_IQ1_M  : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+            case GGML_TYPE_IQ1_M  : return nrc_y >= 999932 ? GGML_TYPE_Q8_K_R8 : type;
             case GGML_TYPE_Q2_K   : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
             case GGML_TYPE_Q3_K   : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
             case GGML_TYPE_Q4_K   : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
@@ -867,6 +867,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         case GGML_TYPE_IQ4_NL_R4:
             return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ1_M_R4:
         case GGML_TYPE_IQ1_BN:
@@ -958,6 +959,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ2_BN_R4:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ1_M_R4:
             return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
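Note (not part of the commit): the new mul_mat_iq1_m_q8_K kernel unpacks the IQ1_M block scales with the scale_shuffle / _mm256_and_si256 / _mm256_mullo_epi16 / _mm256_srli_epi16 sequence above. Below is a minimal scalar sketch of what that sequence appears to compute, read off the constants in the diff: the top 4 bits of each of the four uint16_t in block_iq1_m::scales assemble the fp16 super-block scale, and bits 0-11 hold four 3-bit group scales, each turned into the odd factor 2*s + 1 (helper.val[0..15]). The example bytes, the little-endian cast via memcpy, and this reading of the bit layout are assumptions taken from the diff itself, not from a separate reference implementation.

// Illustrative only: scalar equivalent of the scale unpacking in mul_mat_iq1_m_q8_K above.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    uint8_t scales[8] = {0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0}; // made-up block_iq1_m::scales bytes
    uint16_t sc[4];
    std::memcpy(sc, scales, sizeof(sc)); // little-endian view, as in the kernel's (const uint16_t *) cast

    // Top 4 bits of each uint16_t assemble the fp16 super-block scale (scale.u16 in the kernel).
    uint16_t f16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    std::printf("fp16 bits of d: 0x%04x\n", f16);

    // Bits 0-2, 3-5, 6-8, 9-11 of each uint16_t hold four 3-bit group scales; the kernel's
    // mask/multiply/shift/add turns each 3-bit value s into 2*s + 1, one per group of 16 weights.
    int16_t val[16];
    for (int k = 0; k < 4; ++k)
        for (int g = 0; g < 4; ++g)
            val[4*k + g] = 2*((sc[k] >> (3*g)) & 0x7) + 1;
    for (int i = 0; i < 16; ++i) std::printf("group %2d: %d\n", i, val[i]);
    return 0;
}

The per-weight values in the kernel follow the same fixed-point pattern: each grid byte is multiplied by 8 (_mm256_slli_epi16 by 3) and has 7 or 9 subtracted depending on one qh bit, with the 1/8 folded back in by the 0.125f factor in the final info.store() call.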