author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>   2024-09-10 16:21:57 +0300
committer  GitHub <noreply@github.com>                              2024-09-10 16:21:57 +0300
commit     d17d0c44267bd7d8040626d1006c8377dad4f502 (patch)
tree       f21a9e4bf25227a7e37933c25fc44ce2d2fa2434
parent     a1f7a03f500451be80ec4aeae44665c58cde311f (diff)
iq2_tn: slightly better performance on AVX2 (#47)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp | 75
1 file changed, 48 insertions, 27 deletions
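What the first hunk below does, in brief: `mul_mat_iq2tn_q8_K` now keeps two dequantizers (`deq1`, `deq2`) instead of one, so the two halves of each super-block can be prepared independently, and the `nrc_y == 1` (single right-hand-side column) case gets its own `if constexpr` path. The sketch that follows is not the repository's code; it is a minimal standalone illustration, with made-up data and a made-up scale, of the AVX2 reduction pattern the kernel is built on: `_mm256_maddubs_epi16` multiplies unsigned 8-bit quants by signed 8-bit quants into paired 16-bit sums, `_mm256_madd_epi16` against a vector of ones widens those to 32-bit sums, and the float accumulator is seeded with `_mm256_mul_ps` on the first block and extended with `_mm256_fmadd_ps` thereafter.

```cpp
// Build with: g++ -O2 -mavx2 -mfma dot_sketch.cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    alignas(32) uint8_t x[32];  // stand-in for dequantized iq2_tn values (0..3)
    alignas(32) int8_t  y[32];  // stand-in for q8_K quants
    for (int j = 0; j < 32; ++j) { x[j] = j & 3; y[j] = (int8_t)(j - 16); }

    __m256i vx = _mm256_load_si256((const __m256i *)x);
    __m256i vy = _mm256_load_si256((const __m256i *)y);

    const __m256i m1 = _mm256_set1_epi16(1);
    __m256i sumi16 = _mm256_maddubs_epi16(vx, vy);   // u8*i8 -> paired i16 sums
    __m256i sumi32 = _mm256_madd_epi16(m1, sumi16);  // i16   -> paired i32 sums

    float scale = 0.125f;  // stand-in for deq1.d * q8.scale(iy, i)
    // First block (i == 0) initializes the accumulator ...
    __m256 acc = _mm256_mul_ps(_mm256_set1_ps(scale), _mm256_cvtepi32_ps(sumi32));
    // ... later blocks would fold in with: acc = _mm256_fmadd_ps(vd, sf, acc);

    alignas(32) float out[8];
    _mm256_store_ps(out, acc);
    float total = 0;
    for (float v : out) total += v;
    std::printf("scaled dot product = %g\n", total);
}
```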
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index b5e3cba3..424cba85 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1763,42 +1763,63 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da
     const int nb = n/QK_K;
 
     Q8<nrc_y> q8(info);
-    DequantizerIQ2TN deq(vx, bx);
+    DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx);
 
     __m256 accd[nrc_y];
     const auto m1 = _mm256_set1_epi16(1);
 
     for (int ix = 0; ix < nrc_x; ++ix) {
 
-        deq.new_row(ix);
+        deq1.new_row(ix);
+        deq2.new_row(ix);
 
         for (int i = 0; i < nb; ++i) {
 
-            __m256i sumi[nrc_y];
-            deq.new_block(i);
+            deq1.new_block(i);
 
-            deq.prepare(i, 0);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)),
-                                            _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1)));
-                sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)),
-                                            _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 3))), sumi[iy]);
+            if constexpr (nrc_y == 1) {
+                deq1.prepare(i, 0);
+                auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(0, i, 0)),
+                                              _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(0, i, 1)));
+                sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(0, i, 2)),
+                                                          _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(0, i, 3))), sumi1);
+
+                deq2.prepare(i, 1);
+                auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(0, i, 4)),
+                                              _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(0, i, 5)));
+                sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(0, i, 6)),
+                                                          _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(0, i, 7))), sumi2);
+                auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(0, i)));
+                auto vd = _mm256_set1_ps(deq1.d*q8.scale(0, i));
+                auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi));
+                accd[0] = i > 0 ? _mm256_fmadd_ps(vd, sf, accd[0]) : _mm256_mul_ps(vd, sf);
             }
-            deq.prepare(i, 1);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 4)),
-                                            _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 5))), sumi[iy]);
-                sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 6)),
-                                            _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]);
-                sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i));
-            }
-            if (i > 0) {
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]);
-                }
-            } else {
+            else {
+
+                deq1.prepare(i, 0); deq2.prepare(i, 1);
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])));
+                    auto vd = _mm256_set1_ps(deq1.d*q8.scale(iy, i));
+                    auto sumi = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)),
+                                                 _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1)));
+                    sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)),
+                                                             _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3))), sumi);
+                    sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)),
+                                                             _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi);
+                    sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)),
+                                                             _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi);
+                    sumi = _mm256_sub_epi16(sumi, q8.load_bsums(iy, i));
+
+                    //auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)),
+                    //                              _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1)));
+                    //auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)),
+                    //                              _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3)));
+                    //sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)),
+                    //                                          _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi1);
+                    //sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)),
+                    //                                          _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi2);
+                    //auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(iy, i)));
+                    auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi));
+                    accd[iy] = i > 0 ? _mm256_fmadd_ps(vd, sf, accd[iy]) : _mm256_mul_ps(vd, sf);
                 }
             }
 
@@ -3671,9 +3692,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             mm.funcs[2] = mul_mat_iq2tn_q8_K<3>;
             mm.funcs[3] = mul_mat_iq2tn_q8_K<4>;
             mm.funcs[4] = mul_mat_iq2tn_q8_K<5>;
-            //mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
-            //mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
-            //mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
+            mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
+            mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
+            mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
 #endif
             break;
         case GGML_TYPE_Q3_K:
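The second hunk only uncomments three entries of the `funcs[]` table, so up to eight right-hand-side columns can now be handled in one call (previously, anything past five was presumably split into smaller pieces). A hypothetical reduction of that dispatch pattern follows; the names `mul_mat_stub` and `mul_mat_t` are illustrative, not the repository's API:

```cpp
#include <cstdio>

// One template instantiation per number of right-hand-side columns,
// mirroring mm.funcs[Ny-1] = mul_mat_iq2tn_q8_K<Ny> in MulMat::prepare.
template <int nrc_y> void mul_mat_stub(int n) {
    std::printf("kernel for %d column(s), n = %d\n", nrc_y, n);
}

using mul_mat_t = void (*)(int);

int main() {
    mul_mat_t funcs[8] = {
        mul_mat_stub<1>, mul_mat_stub<2>, mul_mat_stub<3>, mul_mat_stub<4>,
        mul_mat_stub<5>, mul_mat_stub<6>, mul_mat_stub<7>, mul_mat_stub<8>,
    };
    int Ny = 6;          // after this commit, Ny = 6..8 has a direct entry
    funcs[Ny - 1](256);  // prints "kernel for 6 column(s), n = 256"
}
```

Compiling all eight instantiations costs a little code size, but it lets a batch of up to eight q8 columns share one pass over the dequantized left-hand-side row, which, together with the restructured kernel above, is presumably where the "slightly better performance" in the title comes from.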