diff options
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index b23dc6d4..0b29a572 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -239,12 +239,17 @@ struct MulMat { case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; - case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; - case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; - case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; - case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ2_S : return nrc_y >= 16 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; + case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #else @@ -344,10 +349,10 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, //case GGML_TYPE_BF16_R16: // return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs); //case GGML_TYPE_Q2_K: - //case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - //case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K: //case GGML_TYPE_IQ4_XS: //case GGML_TYPE_Q2_K_R4: //case GGML_TYPE_Q3_K_R4: @@ -404,6 +409,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, //case GGML_TYPE_IQ4_NL_R4: // return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16); case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: //case GGML_TYPE_IQ1_S_R4: //case GGML_TYPE_IQ1_M_R4: //case GGML_TYPE_IQ1_BN: @@ -420,6 +426,10 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, } +extern "C" IQK_API int iqk_dequant_type(int type, int Ny) { + return MulMat::is_dequant_better(ggml_type(type), Ny); +} + extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, @@ -597,7 +607,12 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, MulMat mm; auto etypeA = ggml_type(typeA); - if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) { + //auto etypeB = ggml_type(typeB); + auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); + //if (etypeB != GGML_TYPE_F32) { + // if (ith == 0) printf("%s: typeA = %s, typeB = %s, dequant_type = %s\n", __func__, ggml_type_name(etypeA), ggml_type_name(etypeB), ggml_type_name(dequant_type)); + //} + if (dequant_type != etypeA) { if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) { return false; } @@ -612,9 +627,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, first_x *= num_rows; nrc_x *= num_rows; - auto type_size = ggml_type_size(dequant_type); - - size_t row_size_qx = ne00*type_size; + size_t row_size_qx = ggml_row_size(dequant_type, ne00); size_t row_size_qy = strideB; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; @@ -680,9 +693,7 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n first_x *= num_rows; nrc_x *= num_rows; - auto type_size = ggml_type_size(dequant_type); - - size_t row_size_qx = ne00*type_size; + size_t row_size_qx = ggml_row_size(dequant_type, ne00); size_t row_size_qy = strideB; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; |