-rw-r--r--   ggml/src/ggml.c | 161
1 file changed, 158 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index ad092923..eb39d574 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -14581,6 +14581,149 @@ IQK_MulMat_Not_Available:;
 #undef MMID_MATRIX_ROW
 }
 
+#if GGML_USE_IQK_MULMAT
+static void ggml_compute_forward_mul_mat_id_up_gate(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst1,
+        struct ggml_tensor * dst2) {
+
+    GGML_ASSERT(dst1->src[1] == dst2->src[1]);
+    GGML_ASSERT(dst1->src[2] == dst2->src[2]);
+    GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type);
+    GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32);
+
+    const struct ggml_tensor * src1   = dst1->src[1];
+    const struct ggml_tensor * ids    = dst1->src[2];
+    const struct ggml_tensor * src0_1 = dst1->src[0];
+    const struct ggml_tensor * src0_2 = dst2->src[0];
+    const struct ggml_tensor * src0   = src0_1;
+    const struct ggml_tensor * dst    = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+
+    enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+    GGML_ASSERT(ne13 == 1);
+
+    // row groups
+    const int n_ids = ids->ne[0]; // n_expert_used
+    const int n_as  = ne02;       // n_expert
+
+    char * wdata_src1_end = (src1->type == vec_dot_type) ?
+            (char *) params->wdata :
+            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
+
+    struct mmid_row_mapping {
+        int32_t i1;
+        int32_t i2;
+    };
+
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+
+    if (src1->type != vec_dot_type) {
+
+        ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
+
+        char * wdata = params->wdata;
+
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                               (void *)  (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                               ne10);
+                }
+            }
+        }
+    }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
+    if (ith == 0) {
+        // initialize matrix_row_counts
+        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+
+        // group rows by src0 matrix
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+            for (int id = 0; id < n_ids; ++id) {
+                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                assert(i02 >= 0 && i02 < n_as);
+
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+                matrix_row_counts[i02] += 1;
+            }
+        }
+    }
+
+    ggml_barrier(params->shared);
+
+    // compute each matrix multiplication in sequence
+    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+        const int64_t cne1 = matrix_row_counts[cur_a];
+
+        if (cne1 == 0) {
+            continue;
+        }
+
+        const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
+        const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
+
+        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+        const int64_t nr0 = ne01; // src0 rows
+        const int64_t nr1 = cne1; // src1 rows
+
+        if (nth%2 == 0) {
+            const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur;
+            void * dst_d = ith%2 == 0 ? dst1->data : dst2->data;
+            if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+                        type, src0_d, nb01,
+                        vec_dot_type, (const char *)wdata, row_size,
+                        (float *)dst_d, nb1, nb2,
+                        matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error");
+
+        } else {
+            if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+                        src0_1->type, (const char *)src0_1_cur, nb01,
+                        vec_dot_type, (const char *)wdata, row_size,
+                        (float *)dst1->data, nb1, nb2,
+                        matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
+            if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+                        src0_2->type, (const char *)src0_2_cur, nb01,
+                        vec_dot_type, (const char *)wdata, row_size,
+                        (float *)dst2->data, nb1, nb2,
+                        matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
+        }
+    }
+
+#undef MMID_MATRIX_ROW
+}
+#endif
+
 // ggml_compute_forward_out_prod
 
 static void ggml_compute_forward_out_prod_f32(
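The new kernel above does two things worth calling out: a counting pass that groups src1 rows by the expert they were routed to (the MMID_MATRIX_ROW bookkeeping), and, when the thread count is even, an even/odd thread split so half the threads multiply the "up" matrix while the other half multiply the "gate" matrix. The following standalone sketch illustrates both mechanisms; N_EXPERTS, N_TOKENS and N_USED are hypothetical stand-ins for n_as, ids->ne[1] and n_ids, and the router table is made-up data.

    /* Illustrative sketch only -- not the patch itself. */
    #include <stdint.h>
    #include <stdio.h>

    #define N_EXPERTS 4   /* stand-in for n_as (number of experts)      */
    #define N_TOKENS  3   /* stand-in for ids->ne[1] (tokens)            */
    #define N_USED    2   /* stand-in for n_ids (experts used per token) */

    struct row_mapping { int32_t i1, i2; };   /* (expert slot, token index) */

    int main(void) {
        /* hypothetical router output: ids[token][slot] = expert index */
        const int32_t ids[N_TOKENS][N_USED] = { {0, 2}, {2, 3}, {0, 2} };

        int64_t            counts[N_EXPERTS] = {0};
        struct row_mapping rows[N_EXPERTS][N_TOKENS * N_USED];

        /* counting pass, as in the patch: group src1 rows by expert so each
         * expert matrix is multiplied once against all rows routed to it */
        for (int t = 0; t < N_TOKENS; ++t) {
            for (int s = 0; s < N_USED; ++s) {
                const int32_t e = ids[t][s];
                rows[e][counts[e]] = (struct row_mapping){s, t};
                counts[e] += 1;
            }
        }
        for (int e = 0; e < N_EXPERTS; ++e) {
            printf("expert %d: %d rows\n", e, (int)counts[e]);
        }

        /* even/odd split: even threads compute the up projection, odd threads
         * the gate projection, each half re-indexed as worker ith/2 of nth/2 */
        const int nth = 8;
        for (int ith = 0; ith < nth; ++ith) {
            printf("thread %d -> %-4s worker %d/%d\n",
                   ith, ith % 2 == 0 ? "up" : "gate", ith / 2, nth / 2);
        }
        return 0;
    }

When nth is odd the split would leave the halves unequal, which is why the else branch in the patch instead runs the two multiplications back to back with all threads.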
@@ -19007,17 +19150,18 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
-        return;
+        return false;
     }
 
 #if IK_PRINT_TIMING
     int64_t t1 = ggml_time_us();
 #endif
 
+    bool skip_next = false;
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
@@ -19125,6 +19269,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
+#if GGML_USE_IQK_MULMAT
+                if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] &&
+                    tensor->src[0]->type == next->src[0]->type) {
+                    ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next);
+                    skip_next = true;
+                    break;
+                }
+#endif
                 ggml_compute_forward_mul_mat_id(params, tensor);
             } break;
         case GGML_OP_OUT_PROD:
@@ -19367,6 +19519,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     int64_t t2 = ggml_time_us();
     if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
 #endif
+    return skip_next;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -21219,7 +21372,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
             if (ggml_is_noop(node)) continue;
 
-            ggml_compute_forward(&params, node);
+            if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) {
+                ++node_n;
+            }
 
             if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
                 state->shared->ec = GGML_STATUS_ABORTED;
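The scheduler side of the change is a fuse-and-skip pattern: ggml_compute_forward now receives the next graph node, returns true when it consumed it, and ggml_graph_compute_thread advances node_n one extra step. Below is a minimal sketch of that control flow; node_t and the OP_* constants are hypothetical stand-ins for ggml_tensor and enum ggml_op, and the real check additionally requires tensor->src[1] == next->src[1] and matching src0 types.

    /* Minimal sketch of the fuse-and-skip control flow -- not the patch. */
    #include <stdbool.h>
    #include <stdio.h>

    enum { OP_OTHER, OP_MUL_MAT_ID };

    typedef struct { const char * name; int op; } node_t;

    /* returns true when it also computed `next`, like skip_next above */
    static bool compute_forward(const node_t * node, const node_t * next) {
        if (node->op == OP_MUL_MAT_ID && next && next->op == OP_MUL_MAT_ID) {
            printf("fused:  %s + %s\n", node->name, next->name);
            return true;
        }
        printf("single: %s\n", node->name);
        return false;
    }

    int main(void) {
        const node_t nodes[] = {
            {"ffn_up",   OP_MUL_MAT_ID},
            {"ffn_gate", OP_MUL_MAT_ID},
            {"ffn_down", OP_OTHER},
        };
        const int n_nodes = 3;
        for (int i = 0; i < n_nodes; ++i) {
            const node_t * next = i < n_nodes - 1 ? &nodes[i + 1] : NULL;
            if (compute_forward(&nodes[i], next)) {
                ++i;   /* next node was already computed -- skip it */
            }
        }
        return 0;
    }

Run as-is this prints "fused: ffn_up + ffn_gate" followed by "single: ffn_down", mirroring how the patch computes the MoE up and gate projections in one pass and then continues with the rest of the graph.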