-rw-r--r--  ggml/src/ggml.c  161
1 file changed, 158 insertions(+), 3 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index ad092923..eb39d574 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -14581,6 +14581,149 @@ IQK_MulMat_Not_Available:;
#undef MMID_MATRIX_ROW
}
+#if GGML_USE_IQK_MULMAT
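+// Fused evaluation of two MUL_MAT_ID nodes (dst1/dst2, typically the ffn_up and ffn_gate
+// expert projections) that share the same activations (src[1]) and expert ids (src[2]).
+// Both results are computed in a single pass over the per-expert row groups.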
+static void ggml_compute_forward_mul_mat_id_up_gate(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst1,
+ struct ggml_tensor * dst2) {
+
+ GGML_ASSERT(dst1->src[1] == dst2->src[1]);
+ GGML_ASSERT(dst1->src[2] == dst2->src[2]);
+ GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type);
+ GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32);
+
+ const struct ggml_tensor * src1 = dst1->src[1];
+ const struct ggml_tensor * ids = dst1->src[2];
+ const struct ggml_tensor * src0_1 = dst1->src[0];
+ const struct ggml_tensor * src0_2 = dst2->src[0];
+ const struct ggml_tensor * src0 = src0_1;
+ const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const enum ggml_type type = src0->type;
+
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+
+ // we don't support permuted src0 or src1
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+ GGML_ASSERT(ne13 == 1);
+
+ // row groups
+ const int n_ids = ids->ne[0]; // n_expert_used
+ const int n_as = ne02; // n_expert
+
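+ // work buffer layout: optional copy of src1 converted to vec_dot_type, followed by the
+ // per-expert row counts and the row mappings used to gather the src1 rows of each expert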
+ char * wdata_src1_end = (src1->type == vec_dot_type) ?
+ (char *) params->wdata :
+ (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
+
+ struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+ };
+
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+
+ if (src1->type != vec_dot_type) {
+
+ ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
+
+ char * wdata = params->wdata;
+
+ const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+ const size_t nbw2 = nbw1*ne11;
+ const size_t nbw3 = nbw2*ne12;
+
+ assert(params->wsize >= ne13*nbw3);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+ from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+ ne10);
+ }
+ }
+ }
+ }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
+ if (ith == 0) {
+ // initialize matrix_row_counts
+ memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+
+ // group rows by src0 matrix
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+ for (int id = 0; id < n_ids; ++id) {
+ const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+ assert(i02 >= 0 && i02 < n_as);
+
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+ matrix_row_counts[i02] += 1;
+ }
+ }
+ }
+
+ ggml_barrier(params->shared);
+
+ // compute each matrix multiplication in sequence
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+ const int64_t cne1 = matrix_row_counts[cur_a];
+
+ if (cne1 == 0) {
+ continue;
+ }
+
+ const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
+ const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
+
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+ const int64_t nr0 = ne01; // src0 rows
+ const int64_t nr1 = cne1; // src1 rows
+
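+ // with an even number of threads, the two results are computed concurrently: even-numbered
+ // threads handle dst1 (up) and odd-numbered threads handle dst2 (gate), each half using
+ // nth/2 threads; with an odd thread count, all threads compute dst1 and then dst2 in turn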
+ if (nth%2 == 0) {
+ const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur;
+ void * dst_d = ith%2 == 0 ? dst1->data : dst2->data;
+ if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+ type, src0_d, nb01,
+ vec_dot_type, (const char *)wdata, row_size,
+ (float *)dst_d, nb1, nb2,
+ matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error");
+
+ } else {
+ if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+ src0_1->type, (const char *)src0_1_cur, nb01,
+ vec_dot_type, (const char *)wdata, row_size,
+ (float *)dst1->data, nb1, nb2,
+ matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
+ if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+ src0_2->type, (const char *)src0_2_cur, nb01,
+ vec_dot_type, (const char *)wdata, row_size,
+ (float *)dst2->data, nb1, nb2,
+ matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
+ }
+ }
+
+#undef MMID_MATRIX_ROW
+}
+#endif
+
// ggml_compute_forward_out_prod
static void ggml_compute_forward_out_prod_f32(
@@ -19007,17 +19150,18 @@ static void ggml_compute_forward_cross_entropy_loss_back(
/////////////////////////////////
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
GGML_ASSERT(params);
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
- return;
+ return false;
}
#if IK_PRINT_TIMING
int64_t t1 = ggml_time_us();
#endif
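+ // set to true when this call also computes the next node, so that the caller can skip it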
+ bool skip_next = false;
switch (tensor->op) {
case GGML_OP_DUP:
{
@@ -19125,6 +19269,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT_ID:
{
+#if GGML_USE_IQK_MULMAT
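+ // fuse with the next node when it is another MUL_MAT_ID sharing the same activations and
+ // with a matching src0 type (the usual up/gate pair of a MoE FFN layer): compute both at
+ // once and tell the caller to skip the next node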
+ if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] &&
+ tensor->src[0]->type == next->src[0]->type) {
+ ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next);
+ skip_next = true;
+ break;
+ }
+#endif
ggml_compute_forward_mul_mat_id(params, tensor);
} break;
case GGML_OP_OUT_PROD:
@@ -19367,6 +19519,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
int64_t t2 = ggml_time_us();
if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
#endif
+ return skip_next;
}
////////////////////////////////////////////////////////////////////////////////
@@ -21219,7 +21372,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
if (ggml_is_noop(node)) continue;
- ggml_compute_forward(&params, node);
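+ // ggml_compute_forward() returns true when it also computed the following node
+ // (fused MUL_MAT_ID up/gate); in that case the next node is skipped here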
+ if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) {
+ ++node_n;
+ }
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->ec = GGML_STATUS_ABORTED;