author    Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-06-05 19:43:08 +0300
committer Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-06-22 12:02:50 +0300
commit    2ee56b4f0d079b4a1bd58347b13cb85ac5bd1445 (patch)
tree      2cefcd4c9124beccff23ef5707074ed508d5a025
parent    0ad646b9f0b96c449a76d41e4d5ebd4ba16ae690 (diff)
iqk_mul_mat: minor improvements
Current performance:

| model            |     size | threads |  test |           t/s |
| ---------------- | -------: | ------: | ----: | ------------: |
| llama 7B IQ3_S   | 2.75 GiB |      16 | pp512 | 100.21 ± 0.32 |
| llama 7B IQ3_XXS | 2.41 GiB |      16 | pp512 | 105.25 ± 0.75 |
| llama 7B IQ2_M   | 2.20 GiB |      16 | pp512 | 117.88 ± 0.15 |
| llama 7B IQ2_XS  | 1.89 GiB |      16 | pp512 | 136.38 ± 0.24 |
| llama 7B IQ2_XXS | 1.73 GiB |      16 | pp512 | 128.47 ± 0.39 |

mean: 117.64

| model            |     size | threads |  test |           t/s |
| ---------------- | -------: | ------: | ----: | ------------: |
| llama 7B IQ2_XXS | 1.73 GiB |       8 | tg128 |  23.94 ± 0.04 |
| llama 7B IQ2_XS  | 1.89 GiB |       8 | tg128 |  23.27 ± 0.03 |
| llama 7B IQ2_M   | 2.20 GiB |       8 | tg128 |  18.88 ± 0.03 |
| llama 7B IQ3_XXS | 2.41 GiB |       8 | tg128 |  19.07 ± 0.04 |
| llama 7B IQ3_S   | 2.75 GiB |       8 | tg128 |  15.44 ± 0.05 |

mean: 20.12
-rw-r--r--  iqk_mul_mat.cpp  75
1 file changed, 64 insertions, 11 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index e797865d..4e7d27b9 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1192,10 +1192,12 @@ template <typename Bits>
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
if (j == 0) {
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
- sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
- sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
- sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
- sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+ auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+ auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+ auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+ auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+ sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));
+ sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));
#else
const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
@@ -1206,10 +1208,12 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
#endif
} else {
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
- sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
- sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
- sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
- sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+ auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+ auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+ auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+ auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+ sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));
+ sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));
#else
const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
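
(Editorial note, not part of the patch.) The two hunks above replace the old maddubs/dpwssd accumulation with _mm256_dpbusd_epi32 dot products that are scaled afterwards, so each block scale is now applied to complete 4-quant dot products instead of 2-quant partial sums. A minimal scalar sketch of what the three intrinsics in the new VNNI path compute per lane (helper names are illustrative only):

#include <cstdint>

// _mm256_dpbusd_epi32(acc, a, b): for every 32-bit lane, add the dot product of
// four unsigned 8-bit values from a with four signed 8-bit values from b.
static int32_t dpbusd_lane(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
    for (int k = 0; k < 4; ++k) acc += int32_t(a[k]) * int32_t(b[k]);
    return acc;
}

// _mm256_dpwssd_epi32(acc, a, b): for every 32-bit lane, add the dot product of
// two signed 16-bit values from a with two signed 16-bit values from b.
static int32_t dpwssd_lane(int32_t acc, const int16_t a[2], const int16_t b[2]) {
    return acc + int32_t(a[0]) * int32_t(b[0]) + int32_t(a[1]) * int32_t(b[1]);
}

// _mm256_packs_epi32(p1, p2): within each 128-bit half, the four 32-bit values of p1
// are saturated to int16 and stored first, followed by the four values of p2.
// This interleaved order is why the scale layout changes (see set_scales_*_iq below).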
@@ -1221,12 +1225,33 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
}
}
+inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {
+#ifdef HAVE_FANCY_SIMD
+ auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)
+ : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);
+ scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+ scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));
+#else
+ set_scales_8(all_scales, j, scales);
+#endif
+}
+
+inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {
+#ifdef HAVE_FANCY_SIMD
+ auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);
+ scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+ scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));
+#else
+ set_scales_16(all_scales, scales);
+#endif
+}
+
template <typename Dequantizer>
static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_K;
Q8<1> q8(info);
Dequantizer deq(vx, bx);
- __m256i scales[4];
+ __m256i scales[2];
__m256i q8_quants[4];
for (int ix = 0; ix < nrc_x; ++ix) {
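
(Editorial note, not part of the patch.) Under HAVE_FANCY_SIMD, the new set_scales_8_iq broadcasts each 16-bit block scale four times within every 128-bit lane, so that after _mm256_packs_epi32 interleaves the dpbusd results in multiply_add_1, each scale multiplies the dot products of its own block. A scalar spelling-out of one 128-bit lane for j == 0, assuming little-endian x86 and that the lane holds the block scales as consecutive 16-bit values (the helper name is illustrative only):

#include <cstdint>
#include <cstring>

static void scales_8_iq_lane_j0(const uint8_t lane[16], int16_t out0[8], int16_t out1[8]) {
    int16_t s[8];                 // the lane re-interpreted as eight 16-bit scales
    std::memcpy(s, lane, sizeof(s));
    for (int k = 0; k < 4; ++k) {
        out0[k]     = s[0];       // shuffle bytes {0,1} repeated: scale 0 broadcast x4
        out0[4 + k] = s[1];       // shuffle bytes {2,3} repeated: scale 1 broadcast x4
        out1[k]     = s[2];       // the +4 added to the shuffle selects bytes {4,5}
        out1[4 + k] = s[3];       // ... and bytes {6,7}
    }
    // out0/out1 correspond to one lane of scales[0]/scales[1]; each scale now sits
    // next to the four 32-bit dot products that _mm256_packs_epi32 groups together.
}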
@@ -1241,9 +1266,9 @@ static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const Data
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j, q8, q8_quants);
if constexpr (Dequantizer::num_blocks == 8) {
- set_scales_8(all_scales[0], j, scales);
+ set_scales_8_iq(j, all_scales[0], scales);
} else {
- set_scales_16(all_scales[j], scales);
+ set_scales_16_iq(all_scales[j], scales);
}
multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
}
@@ -1254,6 +1279,32 @@ static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const Data
}
}
+// So, if I uncomment this function and the call to it in mul_mat_qX_K_q8_K_IQ_N() below,
+// PP performance improves by ~2-3% (when we have __AVX512VNNI__ and __AVX512VL__).
+// But TG performance for iq3_xs drops by 35%. Seriously? I mean, c'mon,
+// what does the compilation of mul_mat_qX_K_q8_K_IQ_1 (which gets invoked during TG)
+// have to do with the compilation of mul_mat_qX_K_q8_K_IQ_N (invoked during PP)?
+//template <typename Q8, typename Bits>
+//inline void multiply_add_iq(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+//#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+// }
+//#else
+// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+// const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+// const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+// const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+// const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
+// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
+// }
+//#endif
+//}
+
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_K;
@@ -1271,6 +1322,7 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
for (int i = 0; i < nb; ++i) {
__m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];
+ //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
__m256i mins;
float dmin = deq.new_block(i, all_scales, mins);
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -1286,6 +1338,7 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
} else {
set_scales_16(all_scales[j], scales);
}
+ //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
multiply_add(deq.bits, scales, j, i, q8, sumi);
}
for (int iy = 0; iy < nrc_y; ++iy) {