diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-01 08:57:34 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-01 08:57:34 +0300 |
commit | 8cba4789da860d32cfc6d14f96ed37ade9e334bd (patch) | |
tree | c26b35d5a500abf53cc8ca021f4a5b5bffa1261b | |
parent | fd20638bbcb4b1ba69783312bb78545fa418d3f2 (diff) |
iqk_mul_mat: better srategy when nrc_y not divisible by ny (#71)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 39 |
1 files changed, 31 insertions, 8 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 33b0a0d5..568e577c 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -107,16 +107,39 @@ struct MulMat { while (!funcs[ny-1] && ny > 0) --ny; int n_step = (nrc_y - info.cur_y)/ny; if (n_step > 0) { - for (int ix = 0; ix < nrc_x; ix += k_x_step) { - auto this_info = info; - this_info.s += ix; - int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; - for (int iy = 0; iy < n_step; ++iy) { - funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); - this_info.cur_y += ny; + if (n_step*ny != nrc_y) { + ++n_step; + int ny1 = nrc_y/n_step; + int ny2 = ny1 + 1; + int my1 = n_step*ny2 - nrc_y; + int my2 = n_step - my1; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < my1; ++iy) { + funcs[ny1-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); + this_info.cur_y += ny1; + } + for (int iy = 0; iy < my2; ++iy) { + funcs[ny2-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); + this_info.cur_y += ny2; + } + } + info.cur_y += nrc_y; + } + else { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); + this_info.cur_y += ny; + } } + info.cur_y += ny * n_step; } - info.cur_y += ny * n_step; } int n_left = nrc_y - info.cur_y; if (n_left > 0) { |