diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 157 |
1 file changed, 113 insertions(+), 44 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eb39d574..8efe2653 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3845,6 +3845,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "MUL_MAT", "MUL_MAT_ID", "OUT_PROD", + "MOE_FUSED_UP_GATE", "SCALE", "SET", @@ -3904,7 +3905,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3938,6 +3939,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "X*Y", "X[i]*Y", "X*Y", + "X*Y1&X*Y2", "x*v", "y-\\>view(x)", @@ -3997,7 +3999,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6768,6 +6770,51 @@ struct ggml_tensor * ggml_mul_mat_id( return result; } +struct ggml_tensor * ggml_moe_up_gate( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + enum ggml_unary_op op) { + if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); + struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids); + return ggml_fused_mul_unary(ctx, result_gate, result_up, op); + } + GGML_ASSERT(!ggml_is_transposed(as_up)); + GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + GGML_ASSERT(as_up->ne[0] == 
b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + + bool is_node = false; + + if (as_up->grad || as_gate->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MOE_FUSED_UP_GATE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = as_up; + result->src[1] = as_gate; + result->src[2] = b; + result->src[3] = ids; + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + return result; +} + + // ggml_out_prod struct ggml_tensor * ggml_out_prod( @@ -14584,20 +14631,17 @@ IQK_MulMat_Not_Available:; #if GGML_USE_IQK_MULMAT static void ggml_compute_forward_mul_mat_id_up_gate( const struct ggml_compute_params * params, - struct ggml_tensor * dst1, - struct ggml_tensor * dst2) { + struct ggml_tensor * dst) { - GGML_ASSERT(dst1->src[1] == dst2->src[1]); - GGML_ASSERT(dst1->src[2] == dst2->src[2]); - GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type); - GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == dst->src[1]->type); + GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[1])); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - const struct ggml_tensor * src1 = dst1->src[1]; - const struct ggml_tensor * ids = dst1->src[2]; - const struct ggml_tensor * src0_1 = dst1->src[0]; - const struct ggml_tensor * src0_2 = dst2->src[0]; - const struct ggml_tensor * src0 = src0_1; - const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works + const struct ggml_tensor * src1 = dst->src[2]; + const struct ggml_tensor * ids = dst->src[3]; + const struct ggml_tensor * src0_1 = dst->src[0]; + const struct ggml_tensor * src0_2 = dst->src[1]; + const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works GGML_TENSOR_BINARY_OP_LOCALS @@ -14680,6 +14724,9 @@ static void 
ggml_compute_forward_mul_mat_id_up_gate( ggml_barrier(params->shared); + + // so GGML_TENSOR_BINARY_OP_LOCALS works + // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { const int64_t cne1 = matrix_row_counts[cur_a]; @@ -14696,28 +14743,34 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows - - if (nth%2 == 0) { - const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; - void * dst_d = ith%2 == 0 ? dst1->data : dst2->data; - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - type, src0_d, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst_d, nb1, nb2, - matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); - - } else { - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0_1->type, (const char *)src0_1_cur, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst1->data, nb1, nb2, - matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0_2->type, (const char *)src0_2_cur, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst2->data, nb1, nb2, - matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); - } + // + if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], + type, src0_1_cur, src0_2_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)dst->data, nb1, nb2, + matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); + +// if (nth%2 == 0) { +// const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; +// void * dst_d = ith%2 == 0 ? 
dst1->data : dst2->data; +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// type, src0_d, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst_d, nb1, nb2, +// matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); +// +// } else { +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_1->type, (const char *)src0_1_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst1->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_2->type, (const char *)src0_2_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst2->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// } } #undef MMID_MATRIX_ROW @@ -19152,6 +19205,7 @@ static void ggml_compute_forward_cross_entropy_loss_back( static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) { GGML_ASSERT(params); + GGML_UNUSED(next); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { return false; @@ -19269,16 +19323,12 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT_ID: { -#if GGML_USE_IQK_MULMAT - if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] && - tensor->src[0]->type == next->src[0]->type) { - ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next); - skip_next = true; - break; - } -#endif ggml_compute_forward_mul_mat_id(params, tensor); } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + ggml_compute_forward_mul_mat_id_up_gate(params, tensor); + } break; case GGML_OP_OUT_PROD: { ggml_compute_forward_out_prod(params, tensor); @@ -20036,6 +20086,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_MOE_FUSED_UP_GATE: + { + GGML_ABORT("fatal error"); // TODO: 
not implemented + } case GGML_OP_OUT_PROD: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -21046,6 +21100,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_MOE_FUSED_UP_GATE: case GGML_OP_OUT_PROD: { n_tasks = n_threads; @@ -21249,6 +21304,20 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + cur = 0; + const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src2 = node->src[2]; + const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; + if (src2->type != vec_dot_type) { + cur += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + } + const int n_as = src0->ne[2]; + cur += GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src2->ne[2] * sizeof(int64_t); // matrix_rows + } break; case GGML_OP_OUT_PROD: { if (ggml_is_quantized(node->src[0]->type)) { |