Diffstat (limited to 'ggml.c')
-rw-r--r--  ggml.c  57
1 file changed, 23 insertions(+), 34 deletions(-)
diff --git a/ggml.c b/ggml.c
index 7471e792..c9b0a6a0 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(
// ggml_mul_mat_id
+// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
+// this will allow computing all the used experts in a single matrix multiplication
struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
- struct ggml_tensor * const as[],
- int n_as,
+ struct ggml_tensor * as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b) {
GGML_ASSERT(ids->type == GGML_TYPE_I32);
- GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
- GGML_ASSERT(ids->ne[1] == b->ne[1]);
+ GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+ GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
- GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
- GGML_ASSERT(id >= 0 && id < ids->ne[0]);
+ GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+ GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
bool is_node = false;
- if (as[0]->grad || b->grad) {
+ if (as->grad || b->grad) {
is_node = true;
}
- const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+ const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_op_params_i32(result, 0, id);
- ggml_set_op_params_i32(result, 1, n_as);
result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src[0] = ids;
+ result->src[0] = as;
result->src[1] = b;
-
- for (int i = 0; i < n_as; i++) {
- struct ggml_tensor * a = as[i];
- GGML_ASSERT(ggml_are_same_shape(as[0], a));
- GGML_ASSERT(ggml_can_mul_mat(a, b));
- GGML_ASSERT(!ggml_is_transposed(a));
- result->src[i + 2] = a;
- }
+ result->src[2] = ids;
return result;
}
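
A minimal caller-side sketch of the new signature (the shape names n_embd, n_ff, n_expert, n_expert_used and n_tokens are illustrative assumptions, not taken from this diff): the experts are now stacked along dim 2 of a single 3D tensor instead of being passed as an array of 2D matrices, and ids moves to src[2].

    // hypothetical shapes, for illustration only
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd, n_ff, n_expert);  // experts stacked along ne[2]
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, n_tokens); // expert choices, one set per b row
    struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
    // old call: ggml_mul_mat_id(ctx, as_array, n_as, ids, id, b)
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, as, ids, /*id=*/0, b);
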
@@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
- const struct ggml_tensor * ids = dst->src[0];
+ const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
-
- const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
+ const struct ggml_tensor * ids = dst->src[2];
GGML_TENSOR_BINARY_OP_LOCALS
@@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
- // broadcast factors
- const int64_t r2 = ne12/ne02;
- const int64_t r3 = ne13/ne03;
+ // broadcast is not supported with mmid
+ assert(ne12 == 1);
+ assert(ne13 == 1);
// row groups
const int id = ggml_get_op_params_i32(dst, 0);
- const int n_as = ggml_get_op_params_i32(dst, 1);
+ const int n_as = src0->ne[2];
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
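
Note on the removed broadcast: dims 2 and 3 of src0 now carry the expert stack rather than batch dims, so the old factors r2 = ne12/ne02 and r3 = ne13/ne03 would divide by the expert count (ne02 == n_as) and lose their meaning as src1-over-src0 batch ratios; pinning ne12 == ne13 == 1 sidesteps this entirely.
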
@@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
continue;
}
- const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
+ size_t src0_offset = cur_a*src0->nb[2];
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
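
The src0_offset above selects one expert as a plain byte offset into the stacked tensor. A sketch of the stride arithmetic with hypothetical F16 shapes, ne = {4096, 14336, 8, 1}:

    // nb[0] = 2                          bytes per element (ggml_fp16_t)
    // nb[1] = 4096*2     = 8192          bytes per row
    // nb[2] = 14336*8192 = 117440512     bytes per expert
    size_t src0_offset = cur_a*src0->nb[2]; // e.g. expert 3 starts at src0->data + 3*117440512
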
@@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
continue;
}
- assert(ne12 % ne02 == 0);
- assert(ne13 % ne03 == 0);
-
// block-tiling attempt
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
@@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
// broadcast src0 into src1
- const int64_t i03 = i13/r3;
- const int64_t i02 = i12/r2;
+ //const int64_t i03 = i13/r3;
+ //const int64_t i02 = i12/r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
- const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
+ const char * src0_row = (const char *) src0->data + src0_offset;
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
case GGML_OP_MUL_MAT_ID:
{
cur = 0;
- const struct ggml_tensor * src0 = node->src[2];
+ const struct ggml_tensor * src0 = node->src[0];
const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
}
- const int n_as = ggml_get_op_params_i32(node, 1);
+ const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
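
A worked example of the sizing above, with hypothetical n_as = src0->ne[2] = 8 experts and src1->ne[1] = 512 rows: the bookkeeping tail is 8*sizeof(int64_t) = 64 bytes for matrix_row_counts plus 8*512*8 = 32768 bytes for matrix_rows, added on top of any staging space for converting src1 to vec_dot_type and the int64_t alignment pad.
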