summary | refs | log | tree | commit | diff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- ggml/src/iqk/iqk_mul_mat.cpp 64
1 files changed, 23 insertions, 41 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 9a37863b..988fbdc2 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -967,30 +967,19 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
prepare(x[i].qs, x[i].qh);
auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
auto scales16 = _mm256_cvtepi8_epi16(scales8);
- scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift));
+ auto bs = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift));
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
+ auto prod = _mm256_madd_epi16(bs, q8.load_bsums(iy, i));
accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
}
- scales16 = MM256_SET_M128I(scales8, scales8);
- scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
- scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
- }
- template <typename Q8>
- inline void new_block(int i, const Q8& q8, __m256 * accm, __m512 * scales) {
- d = GGML_FP16_TO_FP32(x[i].d);
- prepare(x[i].qs, x[i].qh);
- auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
- auto scales16 = _mm256_cvtepi8_epi16(scales8);
- scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift));
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
- accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
- }
- auto vd = Q8::nrc_y == 1 ? _mm512_set1_ps(d*q8.scale(0, i)) : _mm512_set1_ps(d);
- for (int k = 0; k < 4; ++k) {
- scales[k] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffles[k])))));
- }
+ auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16));
+ auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1));
+ auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1);
+ auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1);
+ scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]);
+ scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]);
+ scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]);
+ scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]);
}
inline __m512i make_one(__m512i l, __m512i h) const {
auto p = _mm512_shuffle_epi8(values[0], l);
@@ -1040,9 +1029,11 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
__m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) };
const __m256i min = _mm256_set1_epi16(-128);
const __m256i shift = _mm256_set1_epi16(1);
- const __m128i shuffles[4] = {
- _mm_set_epi64x(0x0303030302020202, 0x0101010100000000), _mm_set_epi64x(0x0707070706060606, 0x0505050504040404),
- _mm_set_epi64x(0x0b0b0b0b0a0a0a0a, 0x0909090908080808), _mm_set_epi64x(0x0f0f0f0f0e0e0e0e, 0x0d0d0d0d0c0c0c0c),
+ const __m512i shuffles[2] = {
+ _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+ _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3),
+ _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+ _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3)
};
const __m256i shuffle1 = _mm256_set_epi64x(0x0707070703030303, 0x0606060602020202, 0x0505050501010101, 0x0404040400000000);
const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0b0b0b0b, 0x0e0e0e0e0a0a0a0a, 0x0d0d0d0d09090909, 0x0c0c0c0c08080808);
@@ -1135,7 +1126,7 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
__m256 accm[nrc_y];
__m512 accd[nrc_y];
- __m512 scales[4];
+ __m512i scales[4];
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -1149,22 +1140,13 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
deq.new_block(i, q8, accm, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
- const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
- const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
- const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
- const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
- if constexpr (nrc_y == 1) {
- accd[iy] = _mm512_fmadd_ps(scales[0], _mm512_cvtepi32_ps(p1), accd[iy]);
- accd[iy] = _mm512_fmadd_ps(scales[1], _mm512_cvtepi32_ps(p2), accd[iy]);
- accd[iy] = _mm512_fmadd_ps(scales[2], _mm512_cvtepi32_ps(p3), accd[iy]);
- accd[iy] = _mm512_fmadd_ps(scales[3], _mm512_cvtepi32_ps(p4), accd[iy]);
- } else {
- auto d8 = _mm512_set1_ps(q8.scale(iy, i));
- accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[0]), _mm512_cvtepi32_ps(p1), accd[iy]);
- accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[1]), _mm512_cvtepi32_ps(p2), accd[iy]);
- accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[2]), _mm512_cvtepi32_ps(p3), accd[iy]);
- accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[3]), _mm512_cvtepi32_ps(p4), accd[iy]);
- }
+ const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),
+ p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
}