| author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-05 19:43:08 +0300 |
| committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:50 +0300 |
| commit | 2ee56b4f0d079b4a1bd58347b13cb85ac5bd1445 (patch) | |
| tree | 2cefcd4c9124beccff23ef5707074ed508d5a025 | |
| parent | 0ad646b9f0b96c449a76d41e4d5ebd4ba16ae690 (diff) | |
iqk_mul_mat: minor improvements
Current performance:
| model | size | threads | test | t/s |
| ----------------- | ---------: | -------: | ------: | ---------------: |
| llama 7B IQ3_S | 2.75 GiB | 16 | pp512 | 100.21 ± 0.32 |
| llama 7B IQ3_XXS | 2.41 GiB | 16 | pp512 | 105.25 ± 0.75 |
| llama 7B IQ2_M | 2.20 GiB | 16 | pp512 | 117.88 ± 0.15 |
| llama 7B IQ2_XS | 1.89 GiB | 16 | pp512 | 136.38 ± 0.24 |
| llama 7B IQ2_XXS | 1.73 GiB | 16 | pp512 | 128.47 ± 0.39 |
mean (pp512): 117.64 t/s

| model | size | threads | test | t/s |
| ----------------- | ---------: | -------: | ------: | ---------------: |
| llama 7B IQ2_XXS | 1.73 GiB | 8 | tg128 | 23.94 ± 0.04 |
| llama 7B IQ2_XS | 1.89 GiB | 8 | tg128 | 23.27 ± 0.03 |
| llama 7B IQ2_M | 2.20 GiB | 8 | tg128 | 18.88 ± 0.03 |
| llama 7B IQ3_XXS | 2.41 GiB | 8 | tg128 | 19.07 ± 0.04 |
| llama 7B IQ3_S | 2.75 GiB | 8 | tg128 | 15.44 ± 0.05 |
mean (tg128): 20.12 t/s
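The two mean lines are just the arithmetic means of the five t/s entries in each table; a quick standalone check (not part of the commit):

```cpp
// Reproduces the two "mean" lines above from the per-model t/s numbers.
#include <cstdio>

int main() {
    const double pp512[5] = {100.21, 105.25, 117.88, 136.38, 128.47};
    const double tg128[5] = { 23.94,  23.27,  18.88,  19.07,  15.44};
    double sum_pp = 0.0, sum_tg = 0.0;
    for (int i = 0; i < 5; ++i) { sum_pp += pp512[i]; sum_tg += tg128[i]; }
    std::printf("pp512 mean: %.2f t/s\n", sum_pp / 5);  // prints 117.64
    std::printf("tg128 mean: %.2f t/s\n", sum_tg / 5);  // prints 20.12
    return 0;
}
```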
-rw-r--r-- | iqk_mul_mat.cpp | 75 |
1 file changed, 64 insertions, 11 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index e797865d..4e7d27b9 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1192,10 +1192,12 @@ template <typename Bits>
 inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
     if (j == 0) {
 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
-        sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
-        sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
-        sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
-        sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+        sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));
+        sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));
 #else
         const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
         const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
@@ -1206,10 +1208,12 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
 #endif
     } else {
 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
-        sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
-        sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
-        sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
-        sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+        sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));
+        sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));
 #else
         const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
         const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
@@ -1221,12 +1225,33 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
     }
 }
 
+inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {
+#ifdef HAVE_FANCY_SIMD
+    auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)
+                          : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);
+    scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+    scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));
+#else
+    set_scales_8(all_scales, j, scales);
+#endif
+}
+
+inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {
+#ifdef HAVE_FANCY_SIMD
+    auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);
+    scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+    scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));
+#else
+    set_scales_16(all_scales, scales);
+#endif
+}
+
 template <typename Dequantizer>
 static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_K;
     Q8<1> q8(info);
     Dequantizer deq(vx, bx);
-    __m256i scales[4];
+    __m256i scales[2];
     __m256i q8_quants[4];
 
     for (int ix = 0; ix < nrc_x; ++ix) {
@@ -1241,9 +1266,9 @@ static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const Data
             for (int j = 0; j < QK_K/128; ++j) {
                 deq.prepare(i, j, q8, q8_quants);
                 if constexpr (Dequantizer::num_blocks == 8) {
-                    set_scales_8(all_scales[0], j, scales);
+                    set_scales_8_iq(j, all_scales[0], scales);
                 } else {
-                    set_scales_16(all_scales[j], scales);
+                    set_scales_16_iq(all_scales[j], scales);
                 }
                 multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
             }
@@ -1254,6 +1279,32 @@ static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const Data
     }
 }
 
+// So, if I uncomment this function and the call to it in mul_mat_qX_K_q8_K_IQ_N() below,
+// PP performance improves by ~2-3% (when we have __AVX512VNNI__ and __AVX512VL__).
+// But TG performance for iq3_xs drops by 35%. Seriously? I mean, c'mon,
+// what does the compilation of mul_mat_qX_K_q8_K_IQ_1 (which gets invoked during TG)
+// have to do with the compilation of mul_mat_qX_K_q8_K_IQ_N (invoked during PP)?
+//template <typename Q8, typename Bits>
+//inline void multiply_add_iq(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+//#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+//    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+//        sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+//        sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+//        sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+//        sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+//    }
+//#else
+//    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+//        const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+//        const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+//        const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+//        const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+//        sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
+//        sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
+//    }
+//#endif
+//}
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_K;
@@ -1271,6 +1322,7 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
 
         for (int i = 0; i < nb; ++i) {
             __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];
+            //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
             __m256i mins;
             float dmin = deq.new_block(i, all_scales, mins);
             for (int iy = 0; iy < nrc_y; ++iy) {
@@ -1286,6 +1338,7 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
             } else {
                 set_scales_16(all_scales[j], scales);
             }
+            //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
             multiply_add(deq.bits, scales, j, i, q8, sumi);
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
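The gist of the multiply_add_1 change, without decoding the intrinsics: the old VNNI path multiplied the 16-bit pair sums from `_mm256_maddubs_epi16` by the block scales in four `_mm256_dpwssd_epi32` calls; the new path first forms the per-group u8·i8 sums in 32 bits with `_mm256_dpbusd_epi32`, narrows two of those result vectors to 16 bits with `_mm256_packs_epi32`, and then applies the scales with only two `_mm256_dpwssd_epi32` calls. Because the pack interleaves its two inputs per 128-bit lane, the scales must be laid out in matching order, hence the dedicated shuffles in `set_scales_8_iq`/`set_scales_16_iq` under `HAVE_FANCY_SIMD`. The scalar sketch below is not from the commit; it only models the two evaluation orders for a single group of four products, and it assumes the group sum fits in `int16_t` so the signed saturation in the pack never triggers (plausible for the small quant values these kernels produce).

```cpp
// Scalar model (not from the commit) of the old vs. new VNNI evaluation order
// in multiply_add_1, for one 32-bit accumulator lane: one group of four
// unsigned-quant x int8 products sharing a single 16-bit block scale.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Old path: _mm256_maddubs_epi16 forms two int16 pair sums, then
// _mm256_dpwssd_epi32 multiplies each by the (repeated) scale and adds them.
static int32_t old_path(const uint8_t u[4], const int8_t q[4], int16_t scale) {
    int16_t m0 = int16_t(u[0]*q[0] + u[1]*q[1]);   // pair sum, assumed not to saturate
    int16_t m1 = int16_t(u[2]*q[2] + u[3]*q[3]);
    return int32_t(scale)*m0 + int32_t(scale)*m1;
}

// New path: _mm256_dpbusd_epi32 forms the full four-product sum in 32 bits,
// _mm256_packs_epi32 narrows it to int16 (signed saturation), and a single
// _mm256_dpwssd_epi32 applies the scale.
static int32_t new_path(const uint8_t u[4], const int8_t q[4], int16_t scale) {
    int32_t p = u[0]*q[0] + u[1]*q[1] + u[2]*q[2] + u[3]*q[3];
    int16_t packed = int16_t(std::clamp(p, -32768, 32767));  // lossless while |p| <= 32767
    return int32_t(scale) * packed;
}

int main() {
    const uint8_t u[4] = {3, 7, 1, 6};      // small unsigned quant values (hypothetical)
    const int8_t  q[4] = {-5, 12, 9, -3};   // int8 values from the Q8_K side (hypothetical)
    const int16_t scale = 37;               // one block scale (hypothetical)
    std::printf("old = %d, new = %d\n", old_path(u, q, scale), new_path(u, q, scale));
    return 0;
}
```

Both paths print 2220 for these sample values: under the no-saturation assumption the rearrangement is exact, so any speed difference comes from the instruction mix, not from changed arithmetic.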