diff options
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 92 |
1 files changed, 87 insertions, 5 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d04ad22a..6f8c0106 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -239,6 +239,7 @@ struct MulMat { case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; + case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #else @@ -327,6 +328,89 @@ static std::vector<char> & thread_local_work_buffer() { return f; } +bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, size_t stride_y, int nrc_x) { + + switch (typeA) { + //case GGML_TYPE_F16: + //case GGML_TYPE_F32: + //case GGML_TYPE_BF16: + //case GGML_TYPE_BF16_R16: + // return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs); + //case GGML_TYPE_Q2_K: + //case GGML_TYPE_Q3_K: + //case GGML_TYPE_Q4_K: + //case GGML_TYPE_Q5_K: + //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ4_XS: + //case GGML_TYPE_Q2_K_R4: + //case GGML_TYPE_Q3_K_R4: + //case GGML_TYPE_Q4_K_R4: + //case GGML_TYPE_Q5_K_R4: + //case GGML_TYPE_Q6_K_R4: + //case GGML_TYPE_IQ4_XS_R8: + //case GGML_TYPE_Q8_K_R8: + //case GGML_TYPE_Q8_KV: + //case GGML_TYPE_Q8_KV_R8: + // return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16); + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_XXS_R4: + case GGML_TYPE_IQ2_XS_R4: + case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ3_XXS_R4: + case GGML_TYPE_IQ3_S_R4: + return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x); + //case GGML_TYPE_IQ4_KS: + //case GGML_TYPE_IQ5_KS: + //case GGML_TYPE_IQ4_KSS: + //case GGML_TYPE_IQ2_K: + //case GGML_TYPE_IQ2_KS: + //case GGML_TYPE_IQ3_K: + //case GGML_TYPE_IQ4_K: + //case GGML_TYPE_IQ5_K: + //case GGML_TYPE_IQ6_K: + //case GGML_TYPE_IQ2_K_R4: + //case GGML_TYPE_IQ3_K_R4: + //case GGML_TYPE_IQ4_K_R4: + //case GGML_TYPE_IQ5_K_R4: + //case GGML_TYPE_IQ4_KS_R4: + //case GGML_TYPE_IQ5_KS_R4: + // return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16); + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: + return iqk_dequantize_ktquants(typeA, n, vx, bx, vy, stride_y, nrc_x); + //case GGML_TYPE_Q4_0: + //case GGML_TYPE_Q4_1: + //case GGML_TYPE_Q5_0: + //case GGML_TYPE_Q5_1: + //case GGML_TYPE_Q6_0: + //case GGML_TYPE_Q8_0: + //case GGML_TYPE_IQ4_NL: + //case GGML_TYPE_Q4_0_R8: + //case GGML_TYPE_Q5_0_R4: + //case GGML_TYPE_Q6_0_R4: + //case GGML_TYPE_Q8_0_R8: + //case GGML_TYPE_IQ4_NL_R4: + // return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16); + //case GGML_TYPE_IQ1_S: + //case GGML_TYPE_IQ1_S_R4: + //case GGML_TYPE_IQ1_M_R4: + //case GGML_TYPE_IQ1_BN: + //case GGML_TYPE_IQ2_BN: + //case GGML_TYPE_IQ2_BN_R4: + // return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16); + + default: + return false; + } + + return false; +} + } extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, @@ -352,9 +436,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, first_x *= num_rows; nrc_x *= num_rows; - auto type_size = ggml_type_size(dequant_type); - - size_t row_size_qx = ne00*type_size; + size_t row_size_qx = ggml_row_size(dequant_type, ne00); size_t row_size_qy = strideB; //printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx); @@ -368,7 +450,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x); - if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) { + if (!iqk_convert_repack(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) { GGML_ABORT("Fatal error"); } mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny); @@ -678,7 +760,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: - return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false; + return iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16); case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ4_KSS: |