diff options
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 64 |
1 files changed, 23 insertions, 41 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 9a37863b..988fbdc2 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -967,30 +967,19 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { prepare(x[i].qs, x[i].qh); auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales); auto scales16 = _mm256_cvtepi8_epi16(scales8); - scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift)); + auto bs = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift)); for (int iy = 0; iy < Q8::nrc_y; ++iy) { - auto prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i)); + auto prod = _mm256_madd_epi16(bs, q8.load_bsums(iy, i)); accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); } - scales16 = MM256_SET_M128I(scales8, scales8); - scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1)); - scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2)); - } - template <typename Q8> - inline void new_block(int i, const Q8& q8, __m256 * accm, __m512 * scales) { - d = GGML_FP16_TO_FP32(x[i].d); - prepare(x[i].qs, x[i].qh); - auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales); - auto scales16 = _mm256_cvtepi8_epi16(scales8); - scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift)); - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - auto prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i)); - accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); - } - auto vd = Q8::nrc_y == 1 ? _mm512_set1_ps(d*q8.scale(0, i)) : _mm512_set1_ps(d); - for (int k = 0; k < 4; ++k) { - scales[k] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffles[k]))))); - } + auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16)); + auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1)); + auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1); + auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1); + scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]); + scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]); } inline __m512i make_one(__m512i l, __m512i h) const { auto p = _mm512_shuffle_epi8(values[0], l); @@ -1040,9 +1029,11 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> { __m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) }; const __m256i min = _mm256_set1_epi16(-128); const __m256i shift = _mm256_set1_epi16(1); - const __m128i shuffles[4] = { - _mm_set_epi64x(0x0303030302020202, 0x0101010100000000), _mm_set_epi64x(0x0707070706060606, 0x0505050504040404), - _mm_set_epi64x(0x0b0b0b0b0a0a0a0a, 0x0909090908080808), _mm_set_epi64x(0x0f0f0f0f0e0e0e0e, 0x0d0d0d0d0c0c0c0c), + const __m512i shuffles[2] = { + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(), + _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3), + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(), + _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3) }; const __m256i shuffle1 = _mm256_set_epi64x(0x0707070703030303, 0x0606060602020202, 0x0505050501010101, 0x0404040400000000); const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0b0b0b0b, 0x0e0e0e0e0a0a0a0a, 0x0d0d0d0d09090909, 0x0c0c0c0c08080808); @@ -1135,7 +1126,7 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D __m256 accm[nrc_y]; __m512 accd[nrc_y]; - __m512 scales[4]; + __m512i scales[4]; for (int ix = 0; ix < nrc_x; ++ix) { @@ -1149,22 +1140,13 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { - const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); - const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); - const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); - const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); - if constexpr (nrc_y == 1) { - accd[iy] = _mm512_fmadd_ps(scales[0], _mm512_cvtepi32_ps(p1), accd[iy]); - accd[iy] = _mm512_fmadd_ps(scales[1], _mm512_cvtepi32_ps(p2), accd[iy]); - accd[iy] = _mm512_fmadd_ps(scales[2], _mm512_cvtepi32_ps(p3), accd[iy]); - accd[iy] = _mm512_fmadd_ps(scales[3], _mm512_cvtepi32_ps(p4), accd[iy]); - } else { - auto d8 = _mm512_set1_ps(q8.scale(iy, i)); - accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[0]), _mm512_cvtepi32_ps(p1), accd[iy]); - accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[1]), _mm512_cvtepi32_ps(p2), accd[iy]); - accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[2]), _mm512_cvtepi32_ps(p3), accd[iy]); - accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[3]), _mm512_cvtepi32_ps(p4), accd[iy]); - } + const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0)); + const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1)); + const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2)); + const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3)); + auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(), + p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]); + accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } |