summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp28
1 files changed, 19 insertions, 9 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 4a2417b4..923829ab 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1326,7 +1326,7 @@ template <int nrc> struct Q8_K64 {
};
template <int nrc_y>
-static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
__m256 accd[nrc_y];
@@ -1375,17 +1375,28 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto q8_1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), signs[0]);
- auto q8_2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), signs[1]);
- auto dot1 = _mm256_sign_epi8(q8_1, v1);
- auto dot2 = _mm256_sign_epi8(q8_2, v2);
+ if constexpr (nrc_y == 1) {
+ auto dot1 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 0), signs[0]), v1);
+ auto dot2 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 1), signs[1]), v2);
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
#endif
- accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
+ accd[0] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(0, i)), _mm256_cvtepi32_ps(dot), accd[0]);
+ } else {
+ v1 = _mm256_sign_epi8(v1, signs[0]);
+ v2 = _mm256_sign_epi8(v2, signs[1]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+#else
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+#endif
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
+ }
}
}
@@ -1393,7 +1404,6 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
info.store(ix, iy, scale.f * hsum_float_8(accd[iy]));
}
- //x += step;
}
}