summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Kawrakow <iwankawrakow@gmail.com> 2024-10-01 08:57:34 +0300
committer: GitHub <noreply@github.com> 2024-10-01 08:57:34 +0300
commit: 8cba4789da860d32cfc6d14f96ed37ade9e334bd (patch)
tree: c26b35d5a500abf53cc8ca021f4a5b5bffa1261b
parent: fd20638bbcb4b1ba69783312bb78545fa418d3f2 (diff)
iqk_mul_mat: better strategy when nrc_y not divisible by ny (#71)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- ggml/src/iqk/iqk_mul_mat.cpp 39
1 files changed, 31 insertions, 8 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 33b0a0d5..568e577c 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -107,16 +107,39 @@ struct MulMat {
while (!funcs[ny-1] && ny > 0) --ny;
int n_step = (nrc_y - info.cur_y)/ny;
if (n_step > 0) {
- for (int ix = 0; ix < nrc_x; ix += k_x_step) {
- auto this_info = info;
- this_info.s += ix;
- int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
- for (int iy = 0; iy < n_step; ++iy) {
- funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
- this_info.cur_y += ny;
+ if (n_step*ny != nrc_y) {
+ ++n_step;
+ int ny1 = nrc_y/n_step;
+ int ny2 = ny1 + 1;
+ int my1 = n_step*ny2 - nrc_y;
+ int my2 = n_step - my1;
+ for (int ix = 0; ix < nrc_x; ix += k_x_step) {
+ auto this_info = info;
+ this_info.s += ix;
+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
+ for (int iy = 0; iy < my1; ++iy) {
+ funcs[ny1-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
+ this_info.cur_y += ny1;
+ }
+ for (int iy = 0; iy < my2; ++iy) {
+ funcs[ny2-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
+ this_info.cur_y += ny2;
+ }
+ }
+ info.cur_y += nrc_y;
+ }
+ else {
+ for (int ix = 0; ix < nrc_x; ix += k_x_step) {
+ auto this_info = info;
+ this_info.s += ix;
+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
+ for (int iy = 0; iy < n_step; ++iy) {
+ funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
+ this_info.cur_y += ny;
+ }
}
+ info.cur_y += ny * n_step;
}
- info.cur_y += ny * n_step;
}
int n_left = nrc_y - info.cur_y;
if (n_left > 0) {