summary | refs | log | tree | commit | diff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- ggml/src/iqk/iqk_mul_mat.cpp 92
1 files changed, 87 insertions, 5 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d04ad22a..6f8c0106 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -239,6 +239,7 @@ struct MulMat {
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
+ case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#else
@@ -327,6 +328,89 @@ static std::vector<char> & thread_local_work_buffer() {
return f;
}
+bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, size_t stride_y, int nrc_x) { // Forwards the arguments to the conversion helper selected by typeA; returns that helper's result, or false when typeA has no active case below.
+
+ switch (typeA) { // Only the i-quant and kt-quant groups are active; the other type groups are left commented out.
+ //case GGML_TYPE_F16:
+ //case GGML_TYPE_F32:
+ //case GGML_TYPE_BF16:
+ //case GGML_TYPE_BF16_R16:
+ // return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
+ //case GGML_TYPE_Q2_K:
+ //case GGML_TYPE_Q3_K:
+ //case GGML_TYPE_Q4_K:
+ //case GGML_TYPE_Q5_K:
+ //case GGML_TYPE_Q6_K:
+ //case GGML_TYPE_IQ4_XS:
+ //case GGML_TYPE_Q2_K_R4:
+ //case GGML_TYPE_Q3_K_R4:
+ //case GGML_TYPE_Q4_K_R4:
+ //case GGML_TYPE_Q5_K_R4:
+ //case GGML_TYPE_Q6_K_R4:
+ //case GGML_TYPE_IQ4_XS_R8:
+ //case GGML_TYPE_Q8_K_R8:
+ //case GGML_TYPE_Q8_KV:
+ //case GGML_TYPE_Q8_KV_R8:
+ // return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_XXS_R4:
+ case GGML_TYPE_IQ2_XS_R4:
+ case GGML_TYPE_IQ2_S_R4:
+ case GGML_TYPE_IQ3_XXS_R4:
+ case GGML_TYPE_IQ3_S_R4:
+ return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x); // stride_y is not forwarded on this path. NOTE(review): output is presumably Q8_0_R8, going by the helper's name - confirm.
+ //case GGML_TYPE_IQ4_KS:
+ //case GGML_TYPE_IQ5_KS:
+ //case GGML_TYPE_IQ4_KSS:
+ //case GGML_TYPE_IQ2_K:
+ //case GGML_TYPE_IQ2_KS:
+ //case GGML_TYPE_IQ3_K:
+ //case GGML_TYPE_IQ4_K:
+ //case GGML_TYPE_IQ5_K:
+ //case GGML_TYPE_IQ6_K:
+ //case GGML_TYPE_IQ2_K_R4:
+ //case GGML_TYPE_IQ3_K_R4:
+ //case GGML_TYPE_IQ4_K_R4:
+ //case GGML_TYPE_IQ5_K_R4:
+ //case GGML_TYPE_IQ4_KS_R4:
+ //case GGML_TYPE_IQ5_KS_R4:
+ // return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
+ return iqk_dequantize_ktquants(typeA, n, vx, bx, vy, stride_y, nrc_x); // The only path in this function that passes stride_y on to the helper.
+ //case GGML_TYPE_Q4_0:
+ //case GGML_TYPE_Q4_1:
+ //case GGML_TYPE_Q5_0:
+ //case GGML_TYPE_Q5_1:
+ //case GGML_TYPE_Q6_0:
+ //case GGML_TYPE_Q8_0:
+ //case GGML_TYPE_IQ4_NL:
+ //case GGML_TYPE_Q4_0_R8:
+ //case GGML_TYPE_Q5_0_R4:
+ //case GGML_TYPE_Q6_0_R4:
+ //case GGML_TYPE_Q8_0_R8:
+ //case GGML_TYPE_IQ4_NL_R4:
+ // return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
+ //case GGML_TYPE_IQ1_S:
+ //case GGML_TYPE_IQ1_S_R4:
+ //case GGML_TYPE_IQ1_M_R4:
+ //case GGML_TYPE_IQ1_BN:
+ //case GGML_TYPE_IQ2_BN:
+ //case GGML_TYPE_IQ2_BN_R4:
+ // return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);
+
+ default:
+ return false; // Any type without an active case (including every commented-out group above) is reported as not handled.
+ }
+
+ return false; // Not reachable: every path through the switch above returns.
+}
+
}
extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
@@ -352,9 +436,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
first_x *= num_rows;
nrc_x *= num_rows;
- auto type_size = ggml_type_size(dequant_type);
-
- size_t row_size_qx = ne00*type_size;
+ size_t row_size_qx = ggml_row_size(dequant_type, ne00);
size_t row_size_qy = strideB;
//printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx);
@@ -368,7 +450,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
- if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
+ if (!iqk_convert_repack(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
GGML_ABORT("Fatal error");
}
mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
@@ -678,7 +760,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ2_S_R4:
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4:
- return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false;
+ return iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ4_KSS: