-rw-r--r--  iqk_mul_mat.cpp  245
1 file changed, 87 insertions(+), 158 deletions(-)
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 72147615..1068975e 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -723,125 +723,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
};
-//struct SimpleBitsAVX512 {
-// __m512i values[4];
-//};
-//
-//struct SignHelperAVX512 {
-// inline void sign_2_values(const uint16_t * sign_bits, __m512i * values) const {
-// const __mmask64 * mask = (const __mmask64 *)sign_bits;
-// values[0] = _mm512_mask_sub_epi8(values[0], mask[0], _mm512_setzero_si512(), values[0]);
-// values[1] = _mm512_mask_sub_epi8(values[1], mask[1], _mm512_setzero_si512(), values[1]);
-// //auto minus = _mm512_set1_epi8(-1);
-// //auto neg_value = _mm512_sub_epi8(_mm512_xor_si512(values[0], minus), minus);
-// //values[0] = _mm512_mask_blend_epi8(mask[0], values[0], neg_value);
-// //neg_value = _mm512_sub_epi8(_mm512_xor_si512(values[1], minus), minus);
-// //values[1] = _mm512_mask_blend_epi8(mask[1], values[1], neg_value);
-// }
-//};
-//
-//struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
-// DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-//
-// constexpr static int num_blocks = 8;
-//
-// inline __m128i make_scales(int i, float& dd) const {
-// dd = GGML_FP16_TO_FP32(x[i].d);
-// uint32_t aux32[2];
-// std::memcpy(aux32, x[i].scales, 4);
-// aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
-// aux32[0] &= 0x0f0f0f0f;
-// auto scales8 = _mm_shuffle_epi8(_mm_loadl_epi64((const __m128i *)aux32), _mm_set1_epi64x(0x0703060205010400));
-// auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8));
-// return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1));
-// }
-// template <typename Q8>
-// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
-// prepare(i);
-// auto scales16 = make_scales(i, d);
-// scb.accum_mins(scales16, q8, i, -minv*d, accd);
-// auto scales256 = MM256_SET_M128I(scales16, scales16);
-// auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
-// scales[0] = _mm512_shuffle_epi8(all_scales, shuffles512[0]);
-// scales[1] = _mm512_shuffle_epi8(all_scales, shuffles512[1]);
-// }
-//
-// union index_t {
-// __m512i vec;
-// uint32_t val[16];
-// };
-//
-// inline static __m512i make1(const uint8_t * qs, const uint8_t * qh, const __m512i& idx_shift, const __m512i& idx_mask) {
-// auto idx_l = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)qs));
-// auto idx_h = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_set1_epi32(qh[0])), _mm256_set1_epi32(qh[1]), 1);
-// idx_h = _mm512_and_si512(_mm512_sllv_epi32(idx_h, idx_shift), idx_mask);
-// index_t idx; idx.vec = _mm512_or_si512(idx_l, idx_h);
-// return _mm512_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
-// iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]],
-// iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
-// iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
-// ////index_t idx1, idx2;
-// ////auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
-// ////auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
-// ////idx1.vec = _mm256_or_si256(idx_h, idx_l);
-// ////idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8)));
-// ////idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
-// ////idx2.vec = _mm256_or_si256(idx_h, idx_l);
-// ////return _mm512_set_epi32(iq3s_grid[idx2.val[7]], iq3s_grid[idx2.val[6]], iq3s_grid[idx2.val[5]], iq3s_grid[idx2.val[4]],
-// //// iq3s_grid[idx2.val[3]], iq3s_grid[idx2.val[2]], iq3s_grid[idx2.val[1]], iq3s_grid[idx2.val[0]],
-// //// iq3s_grid[idx1.val[7]], iq3s_grid[idx1.val[6]], iq3s_grid[idx1.val[5]], iq3s_grid[idx1.val[4]],
-// //// iq3s_grid[idx1.val[3]], iq3s_grid[idx1.val[2]], iq3s_grid[idx1.val[1]], iq3s_grid[idx1.val[0]]);
-// //////return _mm512_inserti32x8(value, val, 1);
-// //index_t idx;
-// //auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
-// //auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
-// //idx.vec = _mm256_or_si256(idx_h, idx_l);
-// //auto val = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
-// // iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
-// //auto value = _mm512_inserti32x8(_mm512_setzero_si512(), val, 0);
-// //idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8)));
-// //idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
-// //idx.vec = _mm256_or_si256(idx_h, idx_l);
-// //val = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
-// // iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
-// //return _mm512_inserti32x8(value, val, 1);
-// }
-//
-// inline void prepare(int i) {
-// prepare_unsigned(i);
-// auto signs = (const uint16_t *)x[i].signs;
-// sh.sign_2_values(signs+0, bits.values+0);
-// sh.sign_2_values(signs+8, bits.values+2);
-// auto min_value = _mm512_set1_epi8(minv);
-// for (int k = 0; k < 4; ++k) bits.values[k] = _mm512_add_epi8(bits.values[k], min_value);
-// }
-//
-// inline void prepare_unsigned(int i) {
-// auto qs = x[i].qs;
-// auto qh = x[i].qh;
-// bits.values[0] = make1(qs+ 0, qh+0, idx_shift, idx_mask);
-// bits.values[1] = make1(qs+16, qh+2, idx_shift, idx_mask);
-// bits.values[2] = make1(qs+32, qh+4, idx_shift, idx_mask);
-// bits.values[3] = make1(qs+48, qh+6, idx_shift, idx_mask);
-// }
-//
-// constexpr static int minv = 16;
-//
-// SimpleBitsAVX512 bits;
-// SignHelperAVX512 sh;
-// Scales8KBase scb;
-// const __m512i idx_shift = _mm512_set_epi32(1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8);
-// const __m512i idx_mask = _mm512_set1_epi32(256);
-// //const __m256i min_value = _mm256_set1_epi8(minv);
-// const __m512i shuffles512[2] = {
-// _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
-// 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
-// _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
-// 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
-// };
-//
-//};
-
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -2154,7 +2035,7 @@ struct Q_Unpacker {
struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- inline static int block_size() { return QK4_0; }
+ inline static int block_size() { return QK8_0; }
};
struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
@@ -2173,22 +2054,6 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_
inline static int block_size() { return QK4_1; }
};
-template <int nrc_y>
-void mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%Q8_0_Unpacker::block_size() == 0);
- Q8<nrc_y, block_q8_0> q8(info);
- int nb = n/Q8_0_Unpacker::block_size();
- if (nb%4 == 0) {
- mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- } else {
- mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(
- nb, vx, bx, info, q8.y, nrc_x
- );
- }
-}
-
template <int nrc> struct QF32 {
constexpr static int nrc_y = nrc;
QF32(const DataInfo& info) {
@@ -2332,8 +2197,75 @@ void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info,
}
//#endif
+template <int nrc> struct Q80 {
+ constexpr static int nrc_y = nrc;
+ Q80(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
+ }
+ IQK_ALWAYS_INLINE __m256i load1(int iy, int i) const { return _mm256_loadu_si256((const __m256i *)y[iy][i].qs); }
+ IQK_ALWAYS_INLINE float scale(int iy, int i) const { return GGML_FP16_TO_FP32(y[iy][i].d); }
+
+ const block_q8_0 * y[nrc_y];
+};
+inline __m256i mul_q80(__m256i x, __m256i y) {
+ auto ux = _mm256_sign_epi8(x, x);
+#ifdef HAVE_FANCY_SIMD
+ return _mm256_dpbusd_epi32(_mm256_setzero_si256(), ux, _mm256_sign_epi8(y, x));
+#else
+ return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(ux, _mm256_sign_epi8(y, x)));
+#endif
+}
+template <int nrc_y>
+void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK8_0 == 0);
+ constexpr int k_nx = 4;
+ int nb = n/QK8_0;
+ Q80<nrc_y> q8(info);
+ const block_q8_0 * x[k_nx];
+ float ds[k_nx];
+ __m256 acc[k_nx*nrc_y];
+ __m256i xv[k_nx];
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ int ix0 = k_nx*ix;
+ for (int kx = 0; kx < k_nx; ++kx) {
+ x[kx] = (const block_q8_0 *)((const char *)vx + (ix0 + kx)*bx);
+ ds[kx] = GGML_FP16_TO_FP32(x[kx][0].d);
+ xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][0].qs);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto yv = q8.load1(iy, 0);
+ float d = q8.scale(iy, 0);
+ for (int kx = 0; kx < k_nx; ++kx) {
+ auto dot = mul_q80(yv, xv[kx]);
+ acc[k_nx*iy + kx] = _mm256_mul_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot));
+ }
+ }
+ for (int i = 1; i < nb; ++i) {
+ for (int kx = 0; kx < k_nx; ++kx) {
+ ds[kx] = GGML_FP16_TO_FP32(x[kx][i].d);
+ xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][i].qs);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto yv = q8.load1(iy, i);
+ float d = q8.scale(iy, i);
+ for (int kx = 0; kx < k_nx; ++kx) {
+ auto dot = mul_q80(yv, xv[kx]);
+ acc[k_nx*iy + kx] = _mm256_fmadd_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot), acc[k_nx*iy + kx]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx]));
+ }
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ // TODO: handle remaining rows
+}
+
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
- if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker>) {
+ if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
+ std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
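The mul_q80 helper added above relies on a standard trick: _mm256_maddubs_epi16 (and _mm256_dpbusd_epi32 on the HAVE_FANCY_SIMD path) wants its first operand unsigned, so _mm256_sign_epi8(x, x) produces |x| and _mm256_sign_epi8(y, x) moves the sign of x onto y, leaving every product x[i]*y[i] unchanged. A minimal scalar sketch of that identity (illustrative helper names, not part of the patch):

#include <cassert>
#include <cstdint>
#include <cstdlib>

// Scalar model of the sign trick in mul_q80: the unsigned-by-signed multiply
// sees |x| as its first operand, with the sign of x transferred onto y.
static int dot_ref(const int8_t * x, const int8_t * y, int n) {
    int s = 0;
    for (int i = 0; i < n; ++i) s += int(x[i]) * int(y[i]);
    return s;
}
static int dot_sign_trick(const int8_t * x, const int8_t * y, int n) {
    int s = 0;
    for (int i = 0; i < n; ++i) {
        unsigned ux = x[i] < 0 ? -x[i] : x[i];                    // _mm256_sign_epi8(x, x)
        int      sy = x[i] < 0 ? -y[i] : (x[i] == 0 ? 0 : y[i]);  // _mm256_sign_epi8(y, x)
        s += int(ux) * sy;
    }
    return s;
}
int main() {
    int8_t x[32], y[32];
    for (int i = 0; i < 32; ++i) {
        x[i] = int8_t(rand() % 255 - 127);   // Q8_0 quants stay in [-127, 127]
        y[i] = int8_t(rand() % 255 - 127);
    }
    assert(dot_ref(x, y, 32) == dot_sign_trick(x, y, 32));
    return 0;
}

With both inputs in [-127, 127], the adjacent-pair sums formed by _mm256_maddubs_epi16 stay within int16 range, so the non-FANCY_SIMD branch cannot saturate.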
@@ -2353,27 +2285,6 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
}
-// else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S>) {
-//#ifdef HAVE_FANCY_SIMD
-// m.funcs[0] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 1>;
-// m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
-// m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
-// m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
-// m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
-// m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
-// m.funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
-// m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
-//#else
-// m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
-// m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
-// m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
-// m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;
-// m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;
-// m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;
-// m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;
-// m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;
-//#endif
-// }
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS> ||
std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {
@@ -2440,6 +2351,19 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00);
return true;
}
+ // Using the standard legacy quant template is slightly faster than tiling
+ // as implemented in mul_mat_q80_q80_T
+// if (typeA == GGML_TYPE_Q8_0) {
+// for (auto& f : mm.funcs) f = nullptr;
+// mm.funcs[0] = mul_mat_q80_q80_T<1>;
+// mm.funcs[1] = mul_mat_q80_q80_T<2>;
+// mm.funcs[2] = mul_mat_q80_q80_T<3>;
+//#ifdef __AVX512F__
+// mm.funcs[3] = mul_mat_q80_q80_T<4>;
+//#endif
+// row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+// return true;
+// }
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
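For reference, the tiling that the comment above refers to: mul_mat_q80_q80_T walks each row once, keeping k_nx = 4 left-hand rows live and updating k_nx*nrc_y accumulators per block, so every loaded y block is reused four times. A self-contained scalar sketch of that loop structure, with plain float rows standing in for the quantized blocks (names are illustrative, not part of the patch):

#include <cstdio>
#include <vector>

int main() {
    constexpr int k_nx = 4, nrc_y = 2, n = 8;    // 4-row tile, 2 right-hand rows
    std::vector<std::vector<float>> x(k_nx,  std::vector<float>(n));
    std::vector<std::vector<float>> y(nrc_y, std::vector<float>(n));
    for (int kx = 0; kx < k_nx;  ++kx) for (int j = 0; j < n; ++j) x[kx][j] = 0.25f*(kx + 1)*(j + 1);
    for (int iy = 0; iy < nrc_y; ++iy) for (int j = 0; j < n; ++j) y[iy][j] = 0.5f*(iy + 1);

    float acc[k_nx*nrc_y] = {};
    for (int j = 0; j < n; ++j)                 // one pass along the row (the blocks)
        for (int iy = 0; iy < nrc_y; ++iy)      // each right-hand row
            for (int kx = 0; kx < k_nx; ++kx)   // the four left-hand rows of the tile
                acc[k_nx*iy + kx] += x[kx][j]*y[iy][j];

    for (int iy = 0; iy < nrc_y; ++iy)
        for (int kx = 0; kx < k_nx; ++kx)
            printf("C[%d][%d] = %g\n", iy, kx, acc[k_nx*iy + kx]);
    return 0;
}

In the real kernel each step also folds in the two fp16 block scales, GGML_FP16_TO_FP32(x[kx][i].d)*GGML_FP16_TO_FP32(y[iy][i].d); per the comment above, this tiled path stays disabled because routing Q8_0 through the standard Q8_0_Unpacker / mul_mat_qX_0_q8_0_T template measured slightly faster.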
@@ -2508,6 +2432,11 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
MulMat::set_functions<Q5_1_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
break;
+ case GGML_TYPE_Q8_0:
+ assert (ne00 % QK8_0 == 0);
+ MulMat::set_functions<Q8_0_Unpacker>(mm);
+ row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+ break;
default:
return false;