Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--  ggml/src/ggml.c | 91
 1 file changed, 89 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5570b1fc..39218ff4 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3338,6 +3338,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GROUP_NORM",
"FUSED_RMS_NORM",
"FUSED_MUL_UNARY",
+ "MULTI_ADD",
"MUL_MAT",
"MUL_MAT_ID",
@@ -3401,7 +3402,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3430,6 +3431,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"group_norm(x)",
"fused_rms_norm(x)",
"fused_mul_unary(x)",
+ "x1+x2+x3+...",
"X*Y",
"X[i]*Y",
@@ -3493,7 +3495,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5106,6 +5108,29 @@ struct ggml_tensor * ggml_add_inplace(
return ggml_add_impl(ctx, a, b, true);
}
+// ggml_multi_add
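+// sums the n_experts consecutive ne0-sized chunks stored behind each row of a;
+// a is typically a view whose row stride spans all n_experts chunks, since the
+// forward pass reads them as data + j*ne0 (see ggml_compute_forward_multi_add_f32)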
+
+struct ggml_tensor * ggml_multi_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_experts) {
+
+ bool is_node = false;
+
+ if (n_experts < 1) {
+ GGML_ABORT("fatal error");
+ }
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MULTI_ADD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
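+ // the addend count travels in op_params rather than as extra src tensors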
+ result->op_params[0] = n_experts;
+
+ return result;
+}
+
// ggml_add_cast
static struct ggml_tensor * ggml_add_cast_impl(
@@ -10425,6 +10450,59 @@ static void ggml_compute_forward_add(
}
}
+static void ggml_compute_forward_multi_add_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ struct ggml_tensor * src = dst->src[0];
+
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+ GGML_ASSERT(src->nb[0] == sizeof(float));
+ GGML_ASSERT(ggml_are_same_shape(src, dst));
+ GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+
+ const int n_add = dst->op_params[0];
+ GGML_ASSERT(n_add > 0);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(dst);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ int64_t ne0 = dst->ne[0];
+
+ for (int i1 = ir0; i1 < ir1; ++i1) {
+
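+ // sum the n_add consecutive ne0-chunks of this source row into the destination row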
+ float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1] );
+ const float * data = (const float *) ((const char *)src->data + i1*src->nb[1]);
+ memset(dst_ptr, 0, ne0*sizeof(float));
+ for (int j = 0; j < n_add; ++j) {
+ ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
+ }
+ }
+}
+
+static void ggml_compute_forward_multi_add(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ switch (dst->type) {
+ case GGML_TYPE_F32: {
+ ggml_compute_forward_multi_add_f32(params, dst);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_add1
static void ggml_compute_forward_add1_f32(
@@ -18202,6 +18280,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_add1(params, tensor);
} break;
+ case GGML_OP_MULTI_ADD:
+ {
+ ggml_compute_forward_multi_add(params, tensor);
+ } break;
case GGML_OP_ACC:
{
ggml_compute_forward_acc(params, tensor);
@@ -18947,6 +19029,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: implement
}
+ case GGML_OP_MULTI_ADD:
+ {
+ GGML_ABORT("fatal error"); // TODO: implement
+ }
case GGML_OP_CONCAT:
{
GGML_ABORT("fatal error"); // TODO: implement
@@ -19996,6 +20082,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_ACC:
+ case GGML_OP_MULTI_ADD:
{
n_tasks = n_threads;
} break;
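
For reference, a minimal usage sketch of the new operator (not part of the commit; the tensor names and sizes are invented, and it assumes the matching ggml_multi_add declaration in ggml.h from the same change set). The source tensor is a one-row view whose backing memory holds n_expert contiguous chunks, matching how the forward pass indexes them:

// multi_add_example.c -- hypothetical, standalone sketch
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int n_embd   = 8;  // invented size
    const int n_expert = 4;  // invented count

    // expert outputs stored back to back: n_expert chunks of n_embd floats
    struct ggml_tensor * experts = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t)n_embd*n_expert);
    ggml_set_f32(experts, 1.0f);

    // one-row view of the first chunk; ggml_compute_forward_multi_add_f32
    // reads the remaining chunks at offsets j*ne0 from this row
    struct ggml_tensor * view = ggml_view_2d(ctx, experts, n_embd, 1,
                                             (size_t)n_embd*n_expert*sizeof(float), 0);

    struct ggml_tensor * sum = ggml_multi_add(ctx, view, n_expert);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, sum);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    // every element of sum should equal n_expert (here 4.0f)
    printf("sum[0] = %f\n", ggml_get_f32_1d(sum, 0));

    ggml_free(ctx);
    return 0;
}

As far as the diff shows, the point of the operator is graph economy: one MULTI_ADD node with the addend count in op_params replaces a chain of n_experts-1 GGML_OP_ADD nodes, and each thread accumulates all addends for its rows in a single pass.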