diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 4a2417b4..923829ab 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1326,7 +1326,7 @@ template <int nrc> struct Q8_K64 { }; template <int nrc_y> -static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_IQ1BN; Q8_K64<nrc_y> q8(info); __m256 accd[nrc_y]; @@ -1375,17 +1375,28 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1), _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1)); - for (int iy = 0; iy < nrc_y; ++iy) { - auto q8_1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), signs[0]); - auto q8_2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), signs[1]); - auto dot1 = _mm256_sign_epi8(q8_1, v1); - auto dot2 = _mm256_sign_epi8(q8_2, v2); + if constexpr (nrc_y == 1) { + auto dot1 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 0), signs[0]), v1); + auto dot2 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 1), signs[1]), v2); #if defined __AVX512VNNI__ && defined __AVX512VL__ auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2); #else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); #endif - accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]); + accd[0] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(0, i)), _mm256_cvtepi32_ps(dot), accd[0]); + } else { + v1 = _mm256_sign_epi8(v1, signs[0]); + v2 = _mm256_sign_epi8(v2, signs[1]); + for (int iy = 0; iy < nrc_y; ++iy) { + 
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2); +#if defined __AVX512VNNI__ && defined __AVX512VL__ + auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2); +#else + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); +#endif + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]); + } } } @@ -1393,7 +1404,6 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn info.store(ix, iy, scale.f * hsum_float_8(accd[iy])); } - //x += step; } } |