Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 245 |
1 file changed, 87 insertions(+), 158 deletions(-)
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 72147615..1068975e 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -723,125 +723,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
 };
 
-//struct SimpleBitsAVX512 {
-//    __m512i values[4];
-//};
-//
-//struct SignHelperAVX512 {
-//    inline void sign_2_values(const uint16_t * sign_bits, __m512i * values) const {
-//        const __mmask64 * mask = (const __mmask64 *)sign_bits;
-//        values[0] = _mm512_mask_sub_epi8(values[0], mask[0], _mm512_setzero_si512(), values[0]);
-//        values[1] = _mm512_mask_sub_epi8(values[1], mask[1], _mm512_setzero_si512(), values[1]);
-//        //auto minus = _mm512_set1_epi8(-1);
-//        //auto neg_value = _mm512_sub_epi8(_mm512_xor_si512(values[0], minus), minus);
-//        //values[0] = _mm512_mask_blend_epi8(mask[0], values[0], neg_value);
-//        //neg_value = _mm512_sub_epi8(_mm512_xor_si512(values[1], minus), minus);
-//        //values[1] = _mm512_mask_blend_epi8(mask[1], values[1], neg_value);
-//    }
-//};
-//
-//struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
-//    DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-//
-//    constexpr static int num_blocks = 8;
-//
-//    inline __m128i make_scales(int i, float& dd) const {
-//        dd = GGML_FP16_TO_FP32(x[i].d);
-//        uint32_t aux32[2];
-//        std::memcpy(aux32, x[i].scales, 4);
-//        aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
-//        aux32[0] &= 0x0f0f0f0f;
-//        auto scales8 = _mm_shuffle_epi8(_mm_loadl_epi64((const __m128i *)aux32), _mm_set1_epi64x(0x0703060205010400));
-//        auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8));
-//        return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1));
-//    }
-//    template <typename Q8>
-//    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
-//        prepare(i);
-//        auto scales16 = make_scales(i, d);
-//        scb.accum_mins(scales16, q8, i, -minv*d, accd);
-//        auto scales256 = MM256_SET_M128I(scales16, scales16);
-//        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
-//        scales[0] = _mm512_shuffle_epi8(all_scales, shuffles512[0]);
-//        scales[1] = _mm512_shuffle_epi8(all_scales, shuffles512[1]);
-//    }
-//
-//    union index_t {
-//        __m512i vec;
-//        uint32_t val[16];
-//    };
-//
-//    inline static __m512i make1(const uint8_t * qs, const uint8_t * qh, const __m512i& idx_shift, const __m512i& idx_mask) {
-//        auto idx_l = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)qs));
-//        auto idx_h = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_set1_epi32(qh[0])), _mm256_set1_epi32(qh[1]), 1);
-//        idx_h = _mm512_and_si512(_mm512_sllv_epi32(idx_h, idx_shift), idx_mask);
-//        index_t idx; idx.vec = _mm512_or_si512(idx_l, idx_h);
-//        return _mm512_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
-//                                iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]],
-//                                iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
-//                                iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
-//        ////index_t idx1, idx2;
-//        ////auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
-//        ////auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
-//        ////idx1.vec = _mm256_or_si256(idx_h, idx_l);
-//        ////idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8)));
-//        ////idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
-//        ////idx2.vec = _mm256_or_si256(idx_h, idx_l);
-//        ////return _mm512_set_epi32(iq3s_grid[idx2.val[7]], iq3s_grid[idx2.val[6]], iq3s_grid[idx2.val[5]], iq3s_grid[idx2.val[4]],
-//        ////                        iq3s_grid[idx2.val[3]], iq3s_grid[idx2.val[2]], iq3s_grid[idx2.val[1]], iq3s_grid[idx2.val[0]],
-//        ////                        iq3s_grid[idx1.val[7]], iq3s_grid[idx1.val[6]], iq3s_grid[idx1.val[5]], iq3s_grid[idx1.val[4]],
-//        ////                        iq3s_grid[idx1.val[3]], iq3s_grid[idx1.val[2]], iq3s_grid[idx1.val[1]], iq3s_grid[idx1.val[0]]);
-//        //////return _mm512_inserti32x8(value, val, 1);
-//        //index_t idx;
-//        //auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
-//        //auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
-//        //idx.vec = _mm256_or_si256(idx_h, idx_l);
-//        //auto val = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
-//        //                            iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
-//        //auto value = _mm512_inserti32x8(_mm512_setzero_si512(), val, 0);
-//        //idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8)));
-//        //idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
-//        //idx.vec = _mm256_or_si256(idx_h, idx_l);
-//        //val = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
-//        //                       iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
-//        //return _mm512_inserti32x8(value, val, 1);
-//    }
-//
-//    inline void prepare(int i) {
-//        prepare_unsigned(i);
-//        auto signs = (const uint16_t *)x[i].signs;
-//        sh.sign_2_values(signs+0, bits.values+0);
-//        sh.sign_2_values(signs+8, bits.values+2);
-//        auto min_value = _mm512_set1_epi8(minv);
-//        for (int k = 0; k < 4; ++k) bits.values[k] = _mm512_add_epi8(bits.values[k], min_value);
-//    }
-//
-//    inline void prepare_unsigned(int i) {
-//        auto qs = x[i].qs;
-//        auto qh = x[i].qh;
-//        bits.values[0] = make1(qs+ 0, qh+0, idx_shift, idx_mask);
-//        bits.values[1] = make1(qs+16, qh+2, idx_shift, idx_mask);
-//        bits.values[2] = make1(qs+32, qh+4, idx_shift, idx_mask);
-//        bits.values[3] = make1(qs+48, qh+6, idx_shift, idx_mask);
-//    }
-//
-//    constexpr static int minv = 16;
-//
-//    SimpleBitsAVX512 bits;
-//    SignHelperAVX512 sh;
-//    Scales8KBase scb;
-//    const __m512i idx_shift = _mm512_set_epi32(1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8);
-//    const __m512i idx_mask = _mm512_set1_epi32(256);
-//    //const __m256i min_value = _mm256_set1_epi8(minv);
-//    const __m512i shuffles512[2] = {
-//        _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
-//                         0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
-//        _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
-//                         0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
-//    };
-//
-//};
-
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
@@ -2154,7 +2035,7 @@ struct Q_Unpacker {
 
 struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
     Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
-    inline static int block_size() { return QK4_0; }
+    inline static int block_size() { return QK8_0; }
 };
 struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
     Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
@@ -2173,22 +2054,6 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_
     inline static int block_size() { return QK4_1; }
 };
 
-template <int nrc_y>
-void mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    assert(n%Q8_0_Unpacker::block_size() == 0);
-    Q8<nrc_y, block_q8_0> q8(info);
-    int nb = n/Q8_0_Unpacker::block_size();
-    if (nb%4 == 0) {
-        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(
-                nb, vx, bx, info, q8.y, nrc_x
-        );
-    } else {
-        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(
-                nb, vx, bx, info, q8.y, nrc_x
-        );
-    }
-}
-
 template <int nrc> struct QF32 {
     constexpr static int nrc_y = nrc;
     QF32(const DataInfo& info) {
@@ -2332,8 +2197,75 @@ void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info,
 }
 //#endif
 
+template <int nrc> struct Q80 {
+    constexpr static int nrc_y = nrc;
+    Q80(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
+    }
+    IQK_ALWAYS_INLINE __m256i load1(int iy, int i) const { return _mm256_loadu_si256((const __m256i *)y[iy][i].qs); }
+    IQK_ALWAYS_INLINE float scale(int iy, int i) const { return GGML_FP16_TO_FP32(y[iy][i].d); }
+
+    const block_q8_0 * y[nrc_y];
+};
+inline __m256i mul_q80(__m256i x, __m256i y) {
+    auto ux = _mm256_sign_epi8(x, x);
+#ifdef HAVE_FANCY_SIMD
+    return _mm256_dpbusd_epi32(_mm256_setzero_si256(), ux, _mm256_sign_epi8(y, x));
+#else
+    return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(ux, _mm256_sign_epi8(y, x)));
+#endif
+}
+template <int nrc_y>
+void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK8_0 == 0);
+    constexpr int k_nx = 4;
+    int nb = n/QK8_0;
+    Q80<nrc_y> q8(info);
+    const block_q8_0 * x[k_nx];
+    float ds[k_nx];
+    __m256 acc[k_nx*nrc_y];
+    __m256i xv[k_nx];
+    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+        int ix0 = k_nx*ix;
+        for (int kx = 0; kx < k_nx; ++kx) {
+            x[kx] = (const block_q8_0 *)((const char *)vx + (ix0 + kx)*bx);
+            ds[kx] = GGML_FP16_TO_FP32(x[kx][0].d);
+            xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][0].qs);
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto yv = q8.load1(iy, 0);
+            float d = q8.scale(iy, 0);
+            for (int kx = 0; kx < k_nx; ++kx) {
+                auto dot = mul_q80(yv, xv[kx]);
+                acc[k_nx*iy + kx] = _mm256_mul_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot));
+            }
+        }
+        for (int i = 1; i < nb; ++i) {
+            for (int kx = 0; kx < k_nx; ++kx) {
+                ds[kx] = GGML_FP16_TO_FP32(x[kx][i].d);
+                xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][i].qs);
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto yv = q8.load1(iy, i);
+                float d = q8.scale(iy, i);
+                for (int kx = 0; kx < k_nx; ++kx) {
+                    auto dot = mul_q80(yv, xv[kx]);
+                    acc[k_nx*iy + kx] = _mm256_fmadd_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot), acc[k_nx*iy + kx]);
+                }
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx]));
+        }
+    }
+    int last_x = k_nx*(nrc_x/k_nx);
+    if (last_x == nrc_x) return;
+    // TODO: handle remaining rows
+}
+
 template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
-    if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker>) {
+    if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
+                  std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
         m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
         m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
         m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
@@ -2353,27 +2285,6 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
         m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;
         m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
     }
-//    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S>) {
-//#ifdef HAVE_FANCY_SIMD
-//        m.funcs[0] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 1>;
-//        m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
-//        m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
-//        m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
-//        m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
-//        m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
-//        m.funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
-//        m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
-//#else
-//        m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
-//        m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
-//        m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
-//        m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;
-//        m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;
-//        m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;
-//        m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;
-//        m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;
-//#endif
-//    }
     else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
                        std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS> ||
                        std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {
@@ -2440,6 +2351,19 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
         row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00);
         return true;
     }
+    // Using the standard legacy quant template is slightly faster than tiling
+    // as implemented in mul_mat_q80_q80_T
+//    if (typeA == GGML_TYPE_Q8_0) {
+//        for (auto& f : mm.funcs) f = nullptr;
+//        mm.funcs[0] = mul_mat_q80_q80_T<1>;
+//        mm.funcs[1] = mul_mat_q80_q80_T<2>;
+//        mm.funcs[2] = mul_mat_q80_q80_T<3>;
+//#ifdef __AVX512F__
+//        mm.funcs[3] = mul_mat_q80_q80_T<4>;
+//#endif
+//        row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+//        return true;
+//    }
 
     row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
 
@@ -2508,6 +2432,11 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
             MulMat::set_functions<Q5_1_Unpacker>(mm);
            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
             break;
+        case GGML_TYPE_Q8_0:
+            assert (ne00 % QK8_0 == 0);
+            MulMat::set_functions<Q8_0_Unpacker>(mm);
+            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+            break;
         default:
             return false;
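
A note on the mul_q80 helper introduced above: neither _mm256_maddubs_epi16 (the AVX2 path) nor _mm256_dpbusd_epi32 (the VNNI path) multiplies signed by signed 8-bit values; both require their first operand to be unsigned. mul_q80 therefore feeds them |x| = _mm256_sign_epi8(x, x) and transfers the sign of x onto y with _mm256_sign_epi8(y, x), so each lane contributes |x| * sign(x) * y = x * y. The following is a minimal scalar sketch of that identity, our own test harness rather than code from this diff. It also exposes the one edge case where the identity breaks (y = -128 with x < 0, because negating -128 wraps). As far as we can tell this is harmless for Q8_0, whose quantized values are produced with a 127/amax scale and so stay in [-127, 127]; that range also keeps the pairwise int16 sums inside _mm256_maddubs_epi16 below its saturation limit (2 * 127 * 127 = 32258 < 32767).

// Scalar model of the sign-transfer trick in mul_q80. Standalone check with
// hypothetical helper names (wrap8, sign_lane); not part of iqk_mul_mat.cpp.
#include <cstdio>

// Wrap an int into int8 range the way SIMD lanes do (modular, no saturation).
static int wrap8(int v) { v &= 0xff; return v >= 128 ? v - 256 : v; }

// One lane of _mm256_sign_epi8(a, b): negate a if b < 0, zero a if b == 0.
static int sign_lane(int a, int b) { return b < 0 ? wrap8(-a) : (b == 0 ? 0 : a); }

int main() {
    int mismatches = 0;
    for (int x = -128; x <= 127; ++x) {
        for (int y = -128; y <= 127; ++y) {
            int ux = sign_lane(x, x) & 0xff; // consumed as an unsigned byte by maddubs/dpbusd
            int sy = sign_lane(y, x);        // y carrying the sign of x
            if (ux * sy != x * y) ++mismatches;
        }
    }
    // Prints 128: the only failures are y == -128 with x < 0, a value that
    // Q8_0 quantization never produces.
    printf("mismatches over all int8 pairs: %d\n", mismatches);
    return 0;
}

The tiled mul_mat_q80_q80_T kernel that uses mul_q80 is left commented out in set_mul_mat because, per the comment added there, routing GGML_TYPE_Q8_0 through the generic mul_mat_qX_0_q8_0_T template (via the new Q8_0_Unpacker case) measured slightly faster than the k_nx = 4 row tiling.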