author    Kawrakow <iwankawrakow@gmail.com>    2025-02-20 12:41:45 +0200
committer GitHub <noreply@github.com>          2025-02-20 12:41:45 +0200
commit    498a582919f3955fee9ba4239d5f7a298a42425d (patch)
tree      64527c69ba859a2b1104b8688eb9d2ff3ea8468c
parent    a0ebfdd661a2ccb2700b0e36cfc10ca1a2b4de98 (diff)
Optimized GEMM/GEMV for IQ1_S (#212)
* Adding iq1_s to iqk_mul_mat (Zen4)

* iq1_s: slightly better on Zen4

* iq1_s: AVX2

* iq1s: NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp   160
1 file changed, 160 insertions(+), 0 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3bfded73..33e0a4a7 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3666,6 +3666,85 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
}
}
+// sum[ qy_i * ls_k * (qx_i - 1+/-delta_k)]
+// = sum[qy_i * qx_i * ls_k] - 1/8*sum[qy_i * ls_k * (8+/-o_k)]
+// = 1/8 * ( sum[qy_i * qx_i * 8*ls_k] - sum[qy_i * ls_k * (8+/-o_k)] )
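+//
+// Worked example with illustrative numbers (not from the commit): scale bits 5
+// give ls_k = 2*5+1 = 11, and a set sign bit gives delta_k = -1/8. One pair
+// qx_i = 2, qy_i = 3 then contributes
+//     qy_i*ls_k*(qx_i - 1 - 1/8) = 3*11*7/8 = 1/8*(3*2*8*11 - 3*11*9) = 231/8,
+// which is why the code multiplies the scales by 8, folds the -7/-9 offsets
+// into the bsums term, and applies the common 0.125f factor only once.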
+
+template <int nrc_y>
+static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ __m256i qx[8];
+ __m256i scales[4];
+ __m256 acc[nrc_y] = {};
+ auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000
+ __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+ float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
+ auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);
+ auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));
+ scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));
+#ifdef HAVE_FANCY_SIMD
+ auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);
+ auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));
+#else
+ auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);
+ auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));
+#endif
+ deltas128 = _mm_mullo_epi16(scales128, deltas128);
+ scales128 = _mm_slli_epi16(scales128, 3);
+ auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128);
+ auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128);
+ auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7
+ auto all_scales = MM256_SET_M128I(scales128, scales128);
+ auto shuffle = shuffle0;
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle);
+ shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4));
+ }
+ const uint8_t * qs = iq1s[ibl].qs;
+ const uint16_t * qh = iq1s[ibl].qh;
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)],
+ iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+ qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)],
+ iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+ qs += 8;
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0);
+ auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1);
+#ifdef HAVE_FANCY_SIMD
+ auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1);
+ auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2);
+ sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2));
+#else
+ auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1);
+ auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2);
+ auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot));
+#endif
+ }
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas);
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas));
+#endif
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
template <int nrc_y>
static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
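The AVX2/Zen4 kernel above is easier to follow next to a scalar reference. The sketch below is not part of the commit; it assumes the standard ggml block_iq1_s layout (fp16 d, 32 qs bytes, 8 qh words per super-block of 256 values), a hypothetical signed grid table iq1s_grid_s8 holding 8 int8 values in {-1,0,1} per entry, and a per-super-block activation scale d8[ibl] standing in for the Q8_K scales:

    static float iq1_s_dot_ref(const block_iq1_s * x, const int8_t * q8,
                               const float * d8, int nblock) {
        float sum = 0;
        for (int ibl = 0; ibl < nblock; ++ibl) {            // super-blocks of 256 values
            const float d = GGML_FP16_TO_FP32(x[ibl].d);
            const uint8_t  * qs = x[ibl].qs;
            const uint16_t * qh = x[ibl].qh;
            int sumi = 0;
            for (int ib = 0; ib < QK_K/32; ++ib) {          // 8 blocks of 32
                const int ls = 2*((qh[ib] >> 12) & 7) + 1;  // odd block scale, 1..15
                const int o  = (qh[ib] & 0x8000) ? -9 : -7; // 8*(-1 +/- 1/8)
                int dot = 0, bsum = 0;
                for (int l = 0; l < 4; ++l) {
                    const int8_t * g = iq1s_grid_s8 + 8*(qs[l] | (((qh[ib] >> 3*l) & 7) << 8));
                    for (int j = 0; j < 8; ++j) {
                        dot  += q8[j] * (g[j] + 1);         // unsigned grid value, 0..2
                        bsum += q8[j];
                    }
                    q8 += 8;
                }
                sumi += 8*ls*dot + ls*o*bsum;               // factored form from the comment above
                qs += 4;
            }
            sum += d * d8[ibl] * sumi;
        }
        return 0.125f * sum;                                // common 1/8 applied once
    }

In the AVX2 kernel these are exactly the two terms that show up as the scales shifted left by 3 (the 8*ls_k factor) and as deltas (ls_k times -7 or -9) multiplied with the per-block bsums.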
@@ -9473,6 +9552,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_q8_0_r8_q8_1<8>;
expected_typeB = GGML_TYPE_Q8_1_X4;
break;
+ case GGML_TYPE_IQ1_S:
+ mm.funcs[0] = mul_mat_iq1_s_q8_K<1>;
+ mm.funcs[1] = mul_mat_iq1_s_q8_K<2>;
+ mm.funcs[2] = mul_mat_iq1_s_q8_K<3>;
+ mm.funcs[3] = mul_mat_iq1_s_q8_K<4>;
+ mm.funcs[4] = mul_mat_iq1_s_q8_K<5>;
+ mm.funcs[5] = mul_mat_iq1_s_q8_K<6>;
+ mm.funcs[6] = mul_mat_iq1_s_q8_K<7>;
+ mm.funcs[7] = mul_mat_iq1_s_q8_K<8>;
+#ifdef HAVE_FANCY_SIMD
+ mm.func16 = mul_mat_iq1_s_q8_K<16>;
+#endif
+ expected_typeB = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_IQ1_S_R4:
assert (ne00 % QK4_NL == 0);
mm.funcs[0] = mul_mat_iq1_s_r4_q8_1<1>;
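For reference, the funcs table filled in above is indexed by the number of right-hand-side columns a single call handles, each slot being a different instantiation of the same template, while func16 (only set under HAVE_FANCY_SIMD, i.e. AVX512/Zen4) covers 16-column tiles. A rough, hypothetical call-site sketch; only mm.funcs, ne00, info and nrc_x come from the diff, the remaining names are placeholders:

    int ny = Ny <= 8 ? Ny : 8;                               // columns of B handled per call
    mm.funcs[ny - 1](ne00, A_row, row_stride, info, nrc_x);  // the nrc_y == ny instantiation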
@@ -12514,6 +12607,68 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
}
template <int nrc_y>
+static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int8x16_t qx[16];
+ int32x4_t scales[2];
+ int16x4_t deltas[2];
+ float32x4_t acc[nrc_y] = {};
+ auto delta_mask = vdupq_n_u16(0x8000);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+ float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
+ auto qhb = vld1q_u16(iq1s[ibl].qh);
+ auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7));
+ scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1));
+ auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask);
+ // Note: we explicitly assume IQ1S_DELTA = 0.125
+ auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask));
+ //auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask));
+ //deltas128 = vmulq_s16(scales128, deltas128);
+ scales128 = vshlq_n_u16(scales128, 3);
+ auto qs = iq1s[ibl].qs;
+ auto qh = iq1s[ibl].qh;
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]});
+ qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]});
+ qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]});
+ qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]});
+ qs += 8;
+ }
+ scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128)));
+ scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128)));
+ deltas[0] = vget_low_s16 (deltas128);
+ deltas[1] = vget_high_s16(deltas128);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums8(iy, ibl);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums));
+ sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums));
+ for (int k = 0; k < QK_K/128; ++k) {
+ auto qy = q8.load_quants_64(iy, ibl, 2*k+0);
+ auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]);
+ auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]);
+ auto dot12 = vpaddq_s32(dot1, dot2);
+ qy = q8.load_quants_64(iy, ibl, 2*k+1);
+ auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]);
+ auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]);
+ auto dot34 = vpaddq_s32(dot3, dot4);
+ auto dot = vpaddq_s32(dot12, dot34);
+ sumi = vmlaq_s32(sumi, dot, scales[k]);
+ }
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
+ acc[iy] = vdupq_n_f32(0);
+ }
+ }
+}
+
+template <int nrc_y>
static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K128> q8(info);
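The NEON kernel above differs from the AVX2 one in how the delta term is formed: it dots the signed grid values (-1..1) directly against the int8 activations, so the per-block delta contribution collapses to +/- ls_k times that block's bsum (times the common 1/8), which is what the vbicq/vandq select computes and why the comment notes that IQ1S_DELTA = 0.125 is assumed. A scalar sketch of just that select, with a hypothetical helper name:

    // Sign-select performed by vbicq_s16/vandq_s16 in the kernel above.
    // Assumes IQ1S_DELTA == 0.125, so 8*delta == +/-1 and the scale is used as is.
    static inline int16_t iq1_s_delta_term(uint16_t qh, int16_t ls) {
        return (qh & 0x8000) ? (int16_t)(-ls) : ls;      // later multiplied by the block's bsum
    }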
@@ -14327,6 +14482,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.func16 = mul_mat_iq2_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
+ case GGML_TYPE_IQ1_S:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_q8_K);
+ m.func16 = mul_mat_iq1_s_q8_K<16>;
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_IQ1_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1);
m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1;
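In both dispatch tables the expected B type for IQ1_S is GGML_TYPE_Q8_K, i.e. the activations must already be quantized into Q8_K super-blocks (256 int8 quants plus a float scale and 16 block sums) before these kernels run. A minimal sketch of that step using ggml's reference routine, assuming a float row named activations whose length n is a multiple of QK_K:

    std::vector<block_q8_K> q8_row(n / QK_K);            // one block per 256 values
    quantize_row_q8_K(activations, q8_row.data(), n);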