path: root/ggml/src/ggml.c
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- ggml/src/ggml.c | 157
1 file changed, 113 insertions(+), 44 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index eb39d574..8efe2653 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3845,6 +3845,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"MUL_MAT",
"MUL_MAT_ID",
"OUT_PROD",
+ "MOE_FUSED_UP_GATE",
"SCALE",
"SET",
@@ -3904,7 +3905,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3938,6 +3939,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"X*Y",
"X[i]*Y",
"X*Y",
+ "X*Y1&X*Y2",
"x*v",
"y-\\>view(x)",
@@ -3997,7 +3999,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6768,6 +6770,51 @@ struct ggml_tensor * ggml_mul_mat_id(
return result;
}
+struct ggml_tensor * ggml_moe_up_gate(
+ struct ggml_context * ctx,
+ struct ggml_tensor * as_up,
+ struct ggml_tensor * as_gate,
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids,
+ enum ggml_unary_op op) {
+ if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
+ struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
+ struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
+ return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
+ }
+ GGML_ASSERT(!ggml_is_transposed(as_up));
+ GGML_ASSERT(!ggml_is_transposed(as_gate));
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
+ GGML_ASSERT(b->ne[3] == 1); // b is 3d
+ GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+ GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
+ GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat
+ GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
+
+ bool is_node = false;
+
+ if (as_up->grad || as_gate->grad || b->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_MOE_FUSED_UP_GATE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = as_up;
+ result->src[1] = as_gate;
+ result->src[2] = b;
+ result->src[3] = ids;
+
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
+
+ return result;
+}
+
+
// ggml_out_prod
struct ggml_tensor * ggml_out_prod(
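For orientation, here is a sketch of how a graph builder might adopt the new op as a drop-in for the three-node pattern it fuses. Everything below is illustrative: the function and tensor names are hypothetical, the shapes follow the asserts in ggml_moe_up_gate(), and SILU is just one possible activation; ggml_mul_mat_id and ggml_fused_mul_unary are used exactly as in the fallback path above.

#include "ggml.h"

// Sketch only; all names here are hypothetical.
static struct ggml_tensor * build_moe_up_gate(
        struct ggml_context * ctx,
        struct ggml_tensor  * up_exps,   // [n_embd, n_ff, n_expert]
        struct ggml_tensor  * gate_exps, // [n_embd, n_ff, n_expert]
        struct ggml_tensor  * cur,       // activations, [n_embd, 1, n_tokens]
        struct ggml_tensor  * ids,       // expert choices, I32 [n_used, n_tokens]
        bool fused) {
    if (fused) {
        // One node: both matmuls share one quantized copy of the
        // activations and unary(gate)*up is applied inside the kernel.
        return ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, ids,
                                GGML_UNARY_OP_SILU);
    }
    // The unfused pattern this op replaces (identical to the fallback
    // ggml_moe_up_gate itself takes when the weight types differ).
    struct ggml_tensor * up   = ggml_mul_mat_id(ctx, up_exps,   cur, ids);
    struct ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, ids);
    return ggml_fused_mul_unary(ctx, gate, up, GGML_UNARY_OP_SILU);
}

Either branch yields the same [n_ff, n_used, n_tokens] F32 result; the fused path simply reaches it in a single graph node.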
@@ -14584,20 +14631,17 @@ IQK_MulMat_Not_Available:;
#if GGML_USE_IQK_MULMAT
static void ggml_compute_forward_mul_mat_id_up_gate(
const struct ggml_compute_params * params,
- struct ggml_tensor * dst1,
- struct ggml_tensor * dst2) {
+ struct ggml_tensor * dst) {
- GGML_ASSERT(dst1->src[1] == dst2->src[1]);
- GGML_ASSERT(dst1->src[2] == dst2->src[2]);
- GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type);
- GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == dst->src[1]->type);
+ GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[1]));
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
- const struct ggml_tensor * src1 = dst1->src[1];
- const struct ggml_tensor * ids = dst1->src[2];
- const struct ggml_tensor * src0_1 = dst1->src[0];
- const struct ggml_tensor * src0_2 = dst2->src[0];
- const struct ggml_tensor * src0 = src0_1;
- const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works
+ const struct ggml_tensor * src1 = dst->src[2];
+ const struct ggml_tensor * ids = dst->src[3];
+ const struct ggml_tensor * src0_1 = dst->src[0];
+ const struct ggml_tensor * src0_2 = dst->src[1];
+ const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
GGML_TENSOR_BINARY_OP_LOCALS
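(An aside for readers new to this file: GGML_TENSOR_BINARY_OP_LOCALS is the ggml helper macro that pulls the shapes and strides of the tensors named src0, src1 and dst into local variables, which is why the code above aliases src0_1 as src0. Abbreviated, its expansion is roughly:

// Rough, abbreviated sketch of the macro's effect; see ggml.h for the
// authoritative definition.
const int64_t ne00 = src0->ne[0]; const size_t nb00 = src0->nb[0];
const int64_t ne01 = src0->ne[1]; const size_t nb01 = src0->nb[1];
// ...ne02/ne03 and nb02/nb03 likewise, then the same for src1
// (ne10../nb10..) and for dst (ne0../nb0..).

These are the ne01, nb01, ne12, etc. that the rest of the function refers to.)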
@@ -14680,6 +14724,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
ggml_barrier(params->shared);
+
+ // rows have been counted per expert; run the fused up/gate matmuls below
+
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
@@ -14696,28 +14743,34 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows
-
- if (nth%2 == 0) {
- const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur;
- void * dst_d = ith%2 == 0 ? dst1->data : dst2->data;
- if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
- type, src0_d, nb01,
- vec_dot_type, (const char *)wdata, row_size,
- (float *)dst_d, nb1, nb2,
- matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error");
-
- } else {
- if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
- src0_1->type, (const char *)src0_1_cur, nb01,
- vec_dot_type, (const char *)wdata, row_size,
- (float *)dst1->data, nb1, nb2,
- matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
- if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
- src0_2->type, (const char *)src0_2_cur, nb01,
- vec_dot_type, (const char *)wdata, row_size,
- (float *)dst2->data, nb1, nb2,
- matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
- }
+ if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
+ type, src0_1_cur, src0_2_cur, nb01,
+ vec_dot_type, (const char *)wdata, row_size,
+ (float *)dst->data, nb1, nb2,
+ matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
+
}
#undef MMID_MATRIX_ROW
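The retired branch above could only share work when the thread count was even: even-ranked threads computed the up projection and odd-ranked threads the gate, each inside a half-sized pool, while an odd thread count fell back to two full sequential iqk_mul_mat_moe passes. The replacement hands both weight pointers and the activation id to a single kernel that uses every thread for both matrices. Its real prototype is presumably declared next to iqk_mul_mat_moe in this fork's iqk headers; the shape below is inferred purely from the call site, so treat every parameter name and type as a guess:

// Inferred from the call site above, not copied from a header; every name
// and type here is an assumption.
bool iqk_moe_fused_up_gate(
        int64_t nr0, int64_t nr1,            // weight rows / activation rows
        int64_t ne00, int64_t ne11,          // shared inner dim, b columns
        int unary_op,                        // dst->op_params[0], e.g. SILU
        enum ggml_type type,                 // common type of both weight tensors
        const void * up, const void * gate,  // src0_1_cur / src0_2_cur
        size_t nb01,                         // weight row stride
        enum ggml_type vec_dot_type,         // type of the quantized activations
        const char * wdata, size_t row_size, // shared activation work buffer
        float * result, size_t nb1, size_t nb2,
        const void * row_mapping,            // matrix_rows + cur_a*ne12
        int ith, int nth);                   // thread rank and pool size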
@@ -19152,6 +19205,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
GGML_ASSERT(params);
+ GGML_UNUSED(next);
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
return false;
@@ -19269,16 +19323,12 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
} break;
case GGML_OP_MUL_MAT_ID:
{
-#if GGML_USE_IQK_MULMAT
- if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] &&
- tensor->src[0]->type == next->src[0]->type) {
- ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next);
- skip_next = true;
- break;
- }
-#endif
ggml_compute_forward_mul_mat_id(params, tensor);
} break;
+ case GGML_OP_MOE_FUSED_UP_GATE:
+ {
+ ggml_compute_forward_mul_mat_id_up_gate(params, tensor);
+ } break;
case GGML_OP_OUT_PROD:
{
ggml_compute_forward_out_prod(params, tensor);
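The net effect of this hunk: fusion is no longer an execution-time peephole that inspects next and sets skip_next, but an ordinary op selected when the graph is built. For contrast, a fusion pass over an already-built graph would look roughly like the hypothetical sketch below; this commit does not do this, it simply expects the builder to call ggml_moe_up_gate directly.

// Hypothetical post-build rewrite, shown only to contrast with the removed
// runtime check; not part of this commit.
static void fuse_moe_up_gate_pairs(struct ggml_cgraph * gf) {
    for (int i = 0; i + 1 < gf->n_nodes; ++i) {
        struct ggml_tensor * up   = gf->nodes[i];
        struct ggml_tensor * gate = gf->nodes[i + 1];
        if (up->op == GGML_OP_MUL_MAT_ID && gate->op == GGML_OP_MUL_MAT_ID &&
            up->src[1] == gate->src[1] &&             // same activations
            up->src[2] == gate->src[2] &&             // same expert ids
            up->src[0]->type == gate->src[0]->type) { // same weight type
            // ...rewire consumers onto one GGML_OP_MOE_FUSED_UP_GATE node
            // built from (up->src[0], gate->src[0], up->src[1], up->src[2]).
        }
    }
}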
@@ -20036,6 +20086,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
+ case GGML_OP_MOE_FUSED_UP_GATE:
+ {
+ GGML_ABORT("fatal error"); // TODO: not implemented
+ }
case GGML_OP_OUT_PROD:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -21046,6 +21100,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
+ case GGML_OP_MOE_FUSED_UP_GATE:
case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;
@@ -21249,6 +21304,20 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
} break;
+ case GGML_OP_MOE_FUSED_UP_GATE:
+ {
+ cur = 0;
+ const struct ggml_tensor * src0 = node->src[0];
+ const struct ggml_tensor * src2 = node->src[2];
+ const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
+ if (src2->type != vec_dot_type) {
+ cur += ggml_row_size(vec_dot_type, src2->ne[0]) * ggml_nrows(src2);
+ }
+ const int n_as = src0->ne[2];
+ cur += GGML_PAD(cur, sizeof(int64_t)); // align
+ cur += n_as * sizeof(int64_t); // matrix_row_counts
+ cur += n_as * src2->ne[2] * sizeof(int64_t); // matrix_rows
+ } break;
case GGML_OP_OUT_PROD:
{
if (ggml_is_quantized(node->src[0]->type)) {
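Restated outside the switch, the work-buffer budget for the new op (with the src2 fix applied) comes out as below. This is a condensed sketch that assumes ggml.c's internal type_traits table and helpers:

// Work size for GGML_OP_MOE_FUSED_UP_GATE, restated as a helper (sketch).
static size_t moe_fused_up_gate_work_size(const struct ggml_tensor * node) {
    const struct ggml_tensor * src0 = node->src[0]; // up-projection experts
    const struct ggml_tensor * src2 = node->src[2]; // activations b
    const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
    size_t cur = 0;
    if (src2->type != vec_dot_type) {
        // quantize b once; both the up and the gate matmul read this copy
        cur += ggml_row_size(vec_dot_type, src2->ne[0]) * ggml_nrows(src2);
    }
    const int n_as = src0->ne[2];                    // number of experts
    cur += GGML_PAD(cur, sizeof(int64_t));           // align (as in ggml.c)
    cur += n_as * sizeof(int64_t);                   // matrix_row_counts
    cur += n_as * src2->ne[2] * sizeof(int64_t);     // matrix_rows
    return cur;
}

This mirrors the GGML_OP_MUL_MAT_ID case above; the differences are that b now sits in src[2] rather than src[1], and the single quantized copy of b feeds two matrix multiplications instead of one.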