diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 72 |
1 files changed, 31 insertions, 41 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 5f01a610..d9583e8b 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -2400,50 +2400,42 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { } } +template <typename FloatX, typename FloatY> +void set_mul_mat_f(MulMat& mm) { + for (auto& f : mm.funcs) f = nullptr; + mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>; + mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>; + mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>; + mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>; + mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>; +#ifndef __AVX512F__ + mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>; +#endif +} + bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { - //if (Ny == 1 && (typeA == GGML_TYPE_IQ3_S || typeA == GGML_TYPE_IQ3_XXS)) { - if (Ny == 999 && typeA == GGML_TYPE_IQ3_S) { - return false; - } + (void)Ny; - if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F32) { - for (auto& f : mm.funcs) f = nullptr; - mm.funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>; - mm.funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>; - mm.funcs[2] = mul_mat_fX_fY_T<3, ggml_half, float>; - mm.funcs[3] = mul_mat_fX_fY_T<4, ggml_half, float>; - mm.funcs[4] = mul_mat_fX_fY_T<5, ggml_half, float>; -#ifndef __AVX512F__ - mm.funcs[5] = mul_mat_fX_fY_T<6, ggml_half, float>; -#endif + if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) { + if (ne00 % QFBase::k_step) return false; + } + if (typeA == GGML_TYPE_F16) { + switch (typeB) { + case GGML_TYPE_F16: set_mul_mat_f<ggml_half, ggml_half>(mm); break; + case GGML_TYPE_F32: set_mul_mat_f<ggml_half, float>(mm); break; + default: return false; + } return true; } - if (typeA == GGML_TYPE_F32 && typeB == GGML_TYPE_F16) { - for (auto& f : mm.funcs) f = nullptr; - mm.funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>; - mm.funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>; - mm.funcs[2] = mul_mat_fX_fY_T<3, float, ggml_half>; - mm.funcs[3] = mul_mat_fX_fY_T<4, float, ggml_half>; - mm.funcs[4] = mul_mat_fX_fY_T<5, float, ggml_half>; -#ifndef __AVX512F__ - mm.funcs[5] = mul_mat_fX_fY_T<6, float, ggml_half>; -#endif + if (typeA == GGML_TYPE_F32) { + switch (typeB) { + case GGML_TYPE_F16: set_mul_mat_f<float, ggml_half>(mm); break; + case GGML_TYPE_F32: set_mul_mat_f<float, float>(mm); break; + default: return false; + } return true; } - // Using the standard legacy quant template is slightly faster than tiling - // as implemented in mul_mat_q80_q80_T -// if (typeA == GGML_TYPE_Q8_0) { -// for (auto& f : mm.funcs) f = nullptr; -// mm.funcs[0] = mul_mat_q80_q80_T<1>; -// mm.funcs[1] = mul_mat_q80_q80_T<2>; -// mm.funcs[2] = mul_mat_q80_q80_T<3>; -//#ifdef __AVX512F__ -// mm.funcs[3] = mul_mat_q80_q80_T<4>; -//#endif -// row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); -// return true; -// } auto expected_typeB = GGML_TYPE_Q8_K; @@ -2510,7 +2502,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions<Q5_1_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_0; + expected_typeB = GGML_TYPE_Q8_1; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); @@ -2522,9 +2514,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { return false; } - if (typeB != expected_typeB) return false; - - return true; + return ggml_type(typeB) == expected_typeB; } } // namespace |