diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 76 |
1 files changed, 42 insertions, 34 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 9934d2e6..5f01a610 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -120,7 +120,7 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } - static bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny); + static bool set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny); private: template <typename Dequantizer> static void set_functions(MulMat& m); }; @@ -173,43 +173,50 @@ const uint64_t keven_signs[128] = { } -bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B, +bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth) { MulMat mm; - int row_size_q8; - if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) { + if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) { return false; } - auto row_size_qx = ggml_row_size((ggml_type)typeA, ne00); + if (ggml_task_type(task_type) != GGML_TASK_TYPE_COMPUTE) return ggml_task_type(task_type) == GGML_TASK_TYPE_INIT; + + auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); + auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); auto nrc_x = (Nx + nth - 1)/nth; auto first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; - DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, (size_t)row_size_q8, 0, 1, nullptr, 0}; + DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); return true; } -bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B, +bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; assert(row_mapping != nullptr); MulMat mm; - int row_size_q8; - if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) { + if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) { return false; } - int row_size_qx = ggml_row_size((ggml_type)typeA, ne00); + auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); + auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); int nrc_x = (Nx + nth - 1)/nth; int first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; - DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)}; + DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), + row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); return true; } @@ -236,7 +243,6 @@ inline float hsum_float_8(__m256 x) { #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - template <int nrc, typename block_q8 = block_q8_K> struct Q8 { constexpr static int nrc_y = nrc; @@ -2394,14 +2400,14 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { } } -bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny) { +bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { //if (Ny == 1 && (typeA == GGML_TYPE_IQ3_S || typeA == GGML_TYPE_IQ3_XXS)) { if (Ny == 999 && typeA == GGML_TYPE_IQ3_S) { return false; } - if (typeA == GGML_TYPE_F16) { + if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F32) { for (auto& f : mm.funcs) f = nullptr; mm.funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>; mm.funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>; @@ -2411,10 +2417,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int #ifndef __AVX512F__ mm.funcs[5] = mul_mat_fX_fY_T<6, ggml_half, float>; #endif - row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00); return true; } - if (typeA == GGML_TYPE_F32) { + if (typeA == GGML_TYPE_F32 && typeB == GGML_TYPE_F16) { for (auto& f : mm.funcs) f = nullptr; mm.funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>; mm.funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>; @@ -2424,7 +2429,6 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int #ifndef __AVX512F__ mm.funcs[5] = mul_mat_fX_fY_T<6, float, ggml_half>; #endif - row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00); return true; } // Using the standard legacy quant template is slightly faster than tiling @@ -2441,7 +2445,7 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int // return true; // } - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); + auto expected_typeB = GGML_TYPE_Q8_K; switch (typeA) { case GGML_TYPE_Q2_K: @@ -2491,33 +2495,35 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions<Q4_0_Unpacker>(mm); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + expected_typeB = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q4_1: assert (ne00 % QK4_1 == 0); MulMat::set_functions<Q4_1_Unpacker>(mm); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); + expected_typeB = GGML_TYPE_Q8_1; break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); MulMat::set_functions<Q5_0_Unpacker>(mm); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + expected_typeB = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions<Q5_1_Unpacker>(mm); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); + expected_typeB = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); MulMat::set_functions<Q8_0_Unpacker>(mm); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + expected_typeB = GGML_TYPE_Q8_0; break; default: return false; } + if (typeB != expected_typeB) return false; + return true; } @@ -3882,7 +3888,7 @@ IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0 template <int nrc_y> void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%QF16Base::k_step == 0); + GGML_ASSERT(n%QF16Base::k_step == 0); constexpr int k_nx = 5; const char * cx = (const char *)vx; for (int ix = 0; ix < nrc_x/k_nx; ++ix) { @@ -3933,10 +3939,10 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) { } } -bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /*Ny*/) { - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); +bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { - if (typeA == GGML_TYPE_F16) { + if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) { + if (ne00%8) return false; for (auto& f : m.funcs) f = nullptr; m.funcs[0] = mul_mat_f16_f16_T<1>; m.funcs[1] = mul_mat_f16_f16_T<2>; @@ -3945,10 +3951,11 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int / m.funcs[4] = mul_mat_f16_f16_T<5>; //m.funcs[5] = mul_mat_f16_f16_T<6>; //m.funcs[6] = mul_mat_f16_f16_T<7>; - row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00); return true; } + auto expected_Btype = GGML_TYPE_Q8_K; + switch (typeA) { case GGML_TYPE_Q2_K: MulMat::set_functions<DequantizerQ2K>(m); @@ -3985,28 +3992,29 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int / break; case GGML_TYPE_Q4_0: MulMat::set_functions<DequantizerQ40>(m); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q4_1: MulMat::set_functions<DequantizerQ41>(m); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); + expected_Btype = GGML_TYPE_Q8_1; break; case GGML_TYPE_Q5_0: MulMat::set_functions<DequantizerQ50>(m); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q5_1: MulMat::set_functions<DequantizerQ51>(m); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); + expected_Btype = GGML_TYPE_Q8_1; break; case GGML_TYPE_Q8_0: MulMat::set_functions<DequantizerQ80>(m); - row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + expected_Btype = GGML_TYPE_Q8_0; break; default: return false; } - return true; + + return typeB == expected_Btype; } } |