diff options
-rw-r--r-- | examples/batched-bench/batched-bench.cpp | 2 |
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 30 |
2 files changed, 21 insertions, 11 deletions
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 25e7c775..55f825fe 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -139,6 +139,8 @@ int main(int argc, char ** argv) { const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg); if (n_ctx_req > n_kv_max) { + printf("n_ctx_req = %d is greater than n_kv_max = %d for pp = %d, tg = %d, pl = %d\n", + n_ctx_req, n_kv_max, pp, tg, pl); continue; } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f95ce061..cfca477d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -142,13 +142,14 @@ struct MulMat { } int ny = funcs.size(); while (!funcs[ny-1] && ny > 0) --ny; - int n_step = (nrc_y - info.cur_y)/ny; + int n_left = nrc_y - info.cur_y; + int n_step = n_left/ny; if (n_step > 0) { - if (n_step*ny != nrc_y) { + if (n_step*ny != n_left) { ++n_step; - int ny1 = nrc_y/n_step; + int ny1 = n_left/n_step; int ny2 = ny1 + 1; - int my1 = n_step*ny2 - nrc_y; + int my1 = n_step*ny2 - n_left; int my2 = n_step - my1; for (int ix = 0; ix < nrc_x; ix += k_x_step) { auto this_info = info; @@ -163,7 +164,7 @@ struct MulMat { this_info.cur_y += ny2; } } - info.cur_y += nrc_y; + info.cur_y += n_left; } else { for (int ix = 0; ix < nrc_x; ix += k_x_step) { @@ -178,7 +179,7 @@ struct MulMat { info.cur_y += ny * n_step; } } - int n_left = nrc_y - info.cur_y; + n_left = nrc_y - info.cur_y; if (n_left > 0) { funcs[n_left-1](n, vx, bx, info, nrc_x); } @@ -203,6 +204,8 @@ struct MulMat { case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ2_XXS_R4: + case GGML_TYPE_IQ2_XS_R4: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_BN_R4: return 4; @@ -255,11 +258,15 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { return false; } - size_t row_size_qx = strideA; 
//*ggml_type_size(ggml_type(typeA)); - size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB)); - int nrc_x = (Nx + nth - 1)/nth; - int first_x = ith*nrc_x; - if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + size_t row_size_qx = strideA; + size_t row_size_qy = strideB; + auto num_rows = MulMat::num_rows(ggml_type(typeA)); + GGML_ASSERT(Nx%num_rows == 0); + auto nrc_x = (Nx/num_rows + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; + first_x *= num_rows; + nrc_x *= num_rows; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); @@ -13597,6 +13604,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in #ifdef __aarch64__ float16_t q_f16[D*q_step]; #endif + for (int i1 = 0; i1 < nq1/q_step; ++i1) { fms.init_qstep(); kh.reset_block(); |