Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--   ggml/src/ggml.c   91
1 file changed, 89 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5570b1fc..39218ff4 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3338,6 +3338,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GROUP_NORM",
     "FUSED_RMS_NORM",
     "FUSED_MUL_UNARY",
+    "MULTI_ADD",
 
     "MUL_MAT",
     "MUL_MAT_ID",
@@ -3401,7 +3402,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3430,6 +3431,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "group_norm(x)",
     "fused_rms_norm(x)",
     "fused_mul_unary(x)",
+    "x1+x2+x3+...",
 
     "X*Y",
     "X[i]*Y",
@@ -3493,7 +3495,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5106,6 +5108,29 @@ struct ggml_tensor * ggml_add_inplace(
     return ggml_add_impl(ctx, a, b, true);
 }
 
+// ggml_add
+
+struct ggml_tensor * ggml_multi_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_experts) {
+
+    bool is_node = false;
+
+    if (n_experts < 1) {
+        GGML_ABORT("fatal error");
+    }
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MULTI_ADD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->op_params[0] = n_experts;
+
+    return result;
+}
+
 // ggml_add_cast
 
 static struct ggml_tensor * ggml_add_cast_impl(
@@ -10425,6 +10450,59 @@ static void ggml_compute_forward_add(
     }
 }
 
+static void ggml_compute_forward_multi_add_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    struct ggml_tensor * src = dst->src[0];
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src->nb[0] == sizeof(float));
+    GGML_ASSERT(ggml_are_same_shape(src, dst));
+    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+
+    const int n_add = dst->op_params[0];
+    GGML_ASSERT(n_add > 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    int64_t ne0 = dst->ne[0];
+
+    for (int i1 = ir0; i1 < ir1; ++i1) {
+
+        float       * dst_ptr = (float       *) ((char       *) dst->data + i1*dst->nb[1]);
+        const float * data    = (const float *) ((const char *) src->data + i1*src->nb[1]);
+        memset(dst_ptr, 0, ne0*sizeof(float));
+        for (int j = 0; j < n_add; ++j) {
+            ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
+        }
+    }
+}
+
+static void ggml_compute_forward_multi_add(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_multi_add_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
@@ -18202,6 +18280,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add1(params, tensor);
             } break;
+        case GGML_OP_MULTI_ADD:
+            {
+                ggml_compute_forward_multi_add(params, tensor);
+            } break;
         case GGML_OP_ACC:
             {
                 ggml_compute_forward_acc(params, tensor);
@@ -18947,6 +19029,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: implement
             }
+        case GGML_OP_MULTI_ADD:
+            {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
         case GGML_OP_CONCAT:
             {
                 GGML_ABORT("fatal error"); // TODO: implement
@@ -19996,6 +20082,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
+        case GGML_OP_MULTI_ADD:
            {
                n_tasks = n_threads;
            } break;