summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-06-05 08:30:35 +0300
committerGitHub <noreply@github.com>2025-06-05 08:30:35 +0300
commit0b10f7418f7315ef90e35da49e0c053b395fd528 (patch)
treef5210b364436da3aa058a2c3c4f665aff2200470
parent7e79665a31129597634bcef403512aaf4fcdeef9 (diff)
Faster CPU prompt processing for Trellis quants and MoE models (#488)
* Also do the dequantize approach for mul_mat_id * Also do the dequantize approach for iqk_moe_fused_up_gate --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp96
1 files changed, 94 insertions, 2 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index ea8a60d4..d04ad22a 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -322,6 +322,11 @@ struct MulMat {
}
};
+static std::vector<char> & thread_local_work_buffer() {
+ thread_local std::vector<char> f;
+ return f;
+}
+
}
extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
@@ -349,8 +354,6 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
auto type_size = ggml_type_size(dequant_type);
- thread_local std::vector<char> f;
-
size_t row_size_qx = ne00*type_size;
size_t row_size_qy = strideB;
@@ -358,6 +361,8 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
+ auto& f = thread_local_work_buffer();
+
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
auto this_info = info;
this_info.s += ix;
@@ -501,6 +506,47 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
assert(row_mapping != nullptr);
MulMat mm;
+
+ auto etypeA = ggml_type(typeA);
+ if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
+ if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
+ return false;
+ }
+
+ constexpr int k_x_step = 32;
+
+ auto num_rows = MulMat::num_rows(ggml_type(dequant_type));
+ GGML_ASSERT(Nx%num_rows == 0);
+ auto nrc_x = (Nx/num_rows + nth - 1)/nth;
+ auto first_x = ith*nrc_x;
+ if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
+ first_x *= num_rows;
+ nrc_x *= num_rows;
+
+ auto type_size = ggml_type_size(dequant_type);
+
+ size_t row_size_qx = ne00*type_size;
+ size_t row_size_qy = strideB;
+
+ DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
+
+ auto& f = thread_local_work_buffer();
+
+ for (int ix = 0; ix < nrc_x; ix += k_x_step) {
+ auto this_info = info;
+ this_info.s += ix;
+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
+ if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
+ if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
+ GGML_ABORT("Fatal error");
+ }
+ mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
+ }
+
+ return true;
+
+ }
+
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;
}
@@ -528,6 +574,52 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
assert(row_mapping != nullptr);
MulMat mm;
+
+ auto etypeA = ggml_type(typeA);
+ if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
+ if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
+ return false;
+ }
+
+ constexpr int k_x_step = 64;
+
+ auto num_rows = MulMat::num_rows(ggml_type(dequant_type));
+ GGML_ASSERT(Nx%num_rows == 0);
+ auto nrc_x = (Nx/num_rows + nth - 1)/nth;
+ auto first_x = ith*nrc_x;
+ if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
+ first_x *= num_rows;
+ nrc_x *= num_rows;
+
+ auto type_size = ggml_type_size(dequant_type);
+
+ size_t row_size_qx = ne00*type_size;
+ size_t row_size_qy = strideB;
+
+ DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
+
+ auto& f = thread_local_work_buffer();
+
+ for (int ix = 0; ix < nrc_x; ix += k_x_step) {
+ auto this_info = info;
+ this_info.s += ix;
+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
+ if (f.size() < 2*row_size_qx*this_nrc_x) f.resize(2*row_size_qx*this_nrc_x);
+ auto Xu = f.data();
+ auto Xg = f.data() + row_size_qx*this_nrc_x;
+ if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) {
+ GGML_ABORT("Fatal error");
+ }
+ if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
+ GGML_ABORT("Fatal error");
+ }
+ mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
+ }
+
+ return true;
+
+ }
+
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;
}