diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-11 15:12:54 +0300 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:50 +0300 |
commit | 58756ef03ff3f19a98187395d12af3f19f121f90 (patch) | |
tree | b1165a13cff91f06021e273f943deec0a2588f06 | |
parent | 7501184eb4d5f9cb12a160919f6603d75c6bc529 (diff) |
iqk_mul_mat: cleanup
-rw-r--r-- | iqk_mul_mat.cpp | 163 |
1 files changed, 97 insertions, 66 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index d9583e8b..f7571926 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -15,9 +15,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +#if defined IQK_IMPLEMENT +#undef IQK_IMPLEMENT +#endif + +#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD +#define IQK_IMPLEMENT +#endif + #include <cstring> #include <type_traits> -#if defined __x86_64__ || defined __aarch64__ + +#if defined IQK_IMPLEMENT #include "ggml-impl.h" #include "ggml-quants.h" @@ -29,22 +38,25 @@ // clang-format off // This matrix - vector and matrix - matrix multiplication implementation -// for k-quants and IQ4_XS makes prompt processing 150-200% faster -// compared to mainline llama.cpp (and llamafile). -// It is AVX2 only for now. +// for k-quants, i-quants, and legacy quants, makes prompt processing +// 150-350% faster (depending on quantization type) compared to mainline llama.cpp. +// It is AVX2 and ARM_NEON only for now. +// There are also implementations for fp16/32 x fp16/32 matrix multiplications +// on AVX2 and fp16 x fp16 on ARM_NEON. // // Main idea is that unpacking the quants and the block scales to -// be ready for dot products with the corresponding Q8_K quants -// takes time. Hence, if we are performing a QX x Q8_K matrix matrix +// be ready for dot products with the corresponding Q8_X quants +// takes time. Hence, if we are performing a QX x Q8_X matrix matrix // multiplication (as needed for prompt processing), we can get // a significant speedup by reusing the unpacked QX quants and scales -// for multiplication with several Q8_K columns. +// for multiplication with several Q8_X columns. +// +// For fp16/fp32 matri multiplications tiling is used to improve +// performance. #include <utility> #include <array> -#endif - #ifdef _MSC_VER #define IQK_NOINLINE __declspec(noinline) #define IQK_ALWAYS_INLINE inline @@ -79,7 +91,6 @@ struct DataInfo { inline void store(int ix, int iy, float result) const { *(dst_row(iy) + ix) = result; - //dst_row(iy)[ix] = result; } inline float * dst_row(int iy) const { if (!row_mapping) return s + (cur_y + iy)*bs; @@ -120,57 +131,11 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } - static bool set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny); + static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny); private: template <typename Dequantizer> static void set_functions(MulMat& m); }; -inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { - const uint16_t * scales = (const uint16_t *)scales8; - const uint32_t a0 = scales[0] | (scales[1] << 16); - const uint32_t a1 = scales[2] | (scales[3] << 16); - const uint32_t a2 = scales[4] | (scales[5] << 16); - aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030); - aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030); - aux32[2] = a1 & 0x3f3f3f3f; - aux32[0] = a0 & 0x3f3f3f3f; -} - -const uint64_t keven_signs[128] = { - 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, - 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, - 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff, - 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff, - 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff, - 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff, - 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff, - 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff, - 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff, - 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff, - 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff, - 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff, - 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff, - 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff, - 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff, - 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff, - 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff, - 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff, - 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff, - 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff, - 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff, - 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff, - 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff, - 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff, - 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff, - 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff, - 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff, - 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff, - 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff, - 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff, - 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff, - 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff, -}; - } bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00, @@ -179,7 +144,7 @@ bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00, float * C, long stride_C, int ith, int nth) { MulMat mm; - if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) { + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { return false; } @@ -207,7 +172,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, assert(row_mapping != nullptr); MulMat mm; - if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) { + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { return false; } auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); @@ -221,6 +186,56 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, return true; } +namespace { + +inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { + const uint16_t * scales = (const uint16_t *)scales8; + const uint32_t a0 = scales[0] | (scales[1] << 16); + const uint32_t a1 = scales[2] | (scales[3] << 16); + const uint32_t a2 = scales[4] | (scales[5] << 16); + aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030); + aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030); + aux32[2] = a1 & 0x3f3f3f3f; + aux32[0] = a0 & 0x3f3f3f3f; +} + +const uint64_t keven_signs[128] = { + 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, + 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, + 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff, + 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff, + 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff, + 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff, + 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff, + 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff, + 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff, + 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff, + 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff, + 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff, + 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff, + 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff, + 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff, + 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff, + 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff, + 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff, + 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff, + 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff, + 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff, + 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff, + 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff, + 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff, + 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff, + 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff, + 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff, + 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff, + 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff, + 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff, + 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff, + 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff, +}; + +} + #if defined __x86_64__ #if defined HAVE_FANCY_SIMD @@ -2159,6 +2174,8 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_ inline static int block_size() { return QK4_1; } }; +// float matrices - we handle f16 and f32, but only to f32 result + struct QFBase { #ifdef __AVX512F__ constexpr static int k_step = 16; @@ -2203,8 +2220,6 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase { IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } const Float * y[nrc]; }; -//template <int nrc_y> using QF32 = QFT<float, nrc_y>; -//template <int nrc_y> using QF16 = QFT<ggml_half, nrc_y>; template <typename Qy, typename Qx> IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { @@ -2236,6 +2251,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, } for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); } + // This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done // in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. @@ -2264,6 +2280,10 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in } } +// +// Tiled Q8_0 x Q8_0 implementation. Not used as the templated legacy quant implementation +// above is faster. Left behind so we remember we tried. +// template <int nrc> struct Q80 { constexpr static int nrc_y = nrc; Q80(const DataInfo& info) { @@ -2413,7 +2433,7 @@ void set_mul_mat_f(MulMat& mm) { #endif } -bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { +bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { (void)Ny; @@ -3929,7 +3949,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { } } -bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { +bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) { if (ne00%8) return false; @@ -3939,8 +3959,6 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) m.funcs[2] = mul_mat_f16_f16_T<3>; m.funcs[3] = mul_mat_f16_f16_T<4>; m.funcs[4] = mul_mat_f16_f16_T<5>; - //m.funcs[5] = mul_mat_f16_f16_T<6>; - //m.funcs[6] = mul_mat_f16_f16_T<7>; return true; } @@ -4009,4 +4027,17 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) } -#endif // __x86_64__ or __aarch64__ +#endif // __aarch64__ + +#else // IQK_IMPLEMENT + +bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { + return false; +} + +bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, + const void *, int, int) { + return false; +} + +#endif |