summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-08 08:20:26 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:50 +0300
commit299c7f6e89d2d8c4162be06463b82a07540d5691 (patch)
tree5b8ffe39cf3e4dc8c8a3cf69e7a88a6ecd6bc2be /iqk_mul_mat.cpp
parentf0a52f2fbb9d244956e97b0f24da170fcfb75ed9 (diff)
iqk_mul_mat: use block_q8_0_x4 also for AVX2
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp41
1 files changed, 32 insertions, 9 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index ce771100..48bfe0e0 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1749,14 +1749,37 @@ struct UnsignedDot {
template <typename Q8, typename Dot> struct Sum4 {
Dot dot;
inline __m256i compute(const __m256i * qx, const Q8 * y) const {
- const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));
- const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));
- const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));
- const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));
- const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1
- const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3
- return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3
+ if constexpr (std::is_same_v<Q8, block_q8_0>) {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y;
+ const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0));
+ const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1));
+ const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2));
+ const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3));
+ const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1
+ const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3
+ return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3
+ } else {
+ const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));
+ const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));
+ const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));
+ const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));
+ const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1
+ const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3
+ return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3
+ }
+ }
+};
+
+struct ScaleHelperQ8_0 {
+ inline __m128 prepare4(const block_q8_0 * y) {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y;
+ return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4->d));
+ }
+ inline __m128 prepare4(__m128 other_scales, const block_q8_0 * y) {
+ return _mm_mul_ps(other_scales, prepare4(y));
}
+ template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }
+ template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
};
struct ScaleHelperQ_0 {
@@ -1893,11 +1916,11 @@ void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info
Q8<nrc_y, block_q8_0> q8(info);
int nb = n/Unpacker::block_size();
if (nb%4 == 0) {
- mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(
+ mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, true>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
nb, vx, bx, info, q8.y, nrc_x
);
} else {
- mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(
+ mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, false>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
nb, vx, bx, info, q8.y, nrc_x
);
}