summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml.c12
1 files changed, 8 insertions, 4 deletions
diff --git a/ggml.c b/ggml.c
index 6781a5f3..985fc26f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12544,9 +12544,6 @@ static void ggml_compute_forward_mul_mat_id(
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
if (params->type == GGML_TASK_TYPE_INIT) {
- if (ith != 0) {
- return;
- }
char * wdata = params->wdata;
if (src1->type != vec_dot_type) {
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -12554,16 +12551,23 @@ static void ggml_compute_forward_mul_mat_id(
assert(params->wsize >= ne11*ne12*ne13*row_size);
assert(src1->type == GGML_TYPE_F32);
+ int chore = 0;
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
- from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+ if (chore++ % nth == ith) {
+ from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+ }
wdata += row_size;
}
}
}
}
+ if (ith != 0) {
+ return;
+ }
+
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));