summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp12
1 files changed, 4 insertions, 8 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index f112bfa9..6b739a1e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -743,15 +743,11 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
};
struct IQXKScales {
- IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm_set1_epi8(shift)), min(_mm256_set1_epi16(min_val)) {}
+ IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
template <typename Q8>
inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
- auto extra128 = _mm_set1_epi16(extra);
- extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
- extra128 = _mm_and_si128(extra128, eshift);
- extra128 = _mm_shuffle_epi8(extra128, eshuffle);
- auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)),
- _mm256_add_epi16(min, _mm256_cvtepi8_epi16(extra128)));
+ auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle));
+ scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift));
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
@@ -760,7 +756,7 @@ struct IQXKScales {
scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
}
- const __m128i eshift;
+ const __m256i eshift;
const __m256i min;
const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);