Diffstat (limited to 'ggml.c')
-rw-r--r--   ggml.c   57
1 file changed, 23 insertions, 34 deletions
@@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(
 
 // ggml_mul_mat_id
 
+// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
+// this will allow computing all the used experts in a single matrix multiplication
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * const as[],
-        int                   n_as,
+        struct ggml_tensor  * as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {
 
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
-    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+    GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
     GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
-    GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id < ids->ne[0]);
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+    GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
 
     bool is_node = false;
 
-    if (as[0]->grad || b->grad) {
+    if (as->grad || b->grad) {
         is_node = true;
     }
 
-    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     ggml_set_op_params_i32(result, 0, id);
-    ggml_set_op_params_i32(result, 1, n_as);
 
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = ids;
+    result->src[0] = as;
     result->src[1] = b;
-
-    for (int i = 0; i < n_as; i++) {
-        struct ggml_tensor * a = as[i];
-        GGML_ASSERT(ggml_are_same_shape(as[0], a));
-        GGML_ASSERT(ggml_can_mul_mat(a, b));
-        GGML_ASSERT(!ggml_is_transposed(a));
-        result->src[i + 2] = a;
-    }
+    result->src[2] = ids;
 
     return result;
 }
@@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
-
-    const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
+    const struct ggml_tensor * ids = dst->src[2];
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    // broadcast is not supported with mmid
+    assert(ne12 == 1);
+    assert(ne13 == 1);
 
     // row groups
     const int id   = ggml_get_op_params_i32(dst, 0);
-    const int n_as = ggml_get_op_params_i32(dst, 1);
+    const int n_as = src0->ne[2];
 
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
@@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
+        size_t src0_offset = cur_a*src0->nb[2];
 
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
                 continue;
             }
 
-            assert(ne12 % ne02 == 0);
-            assert(ne13 % ne03 == 0);
-
             // block-tiling attempt
             const int64_t blck_0 = 16;
             const int64_t blck_1 = 16;
@@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
                 const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
 
                 // broadcast src0 into src1
-                const int64_t i03 = i13/r3;
-                const int64_t i02 = i12/r2;
+                //const int64_t i03 = i13/r3;
+                //const int64_t i02 = i12/r2;
 
                 const int64_t i1 = i11;
                 const int64_t i2 = i12;
                 const int64_t i3 = i13;
 
-                const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
+                const char * src0_row = (const char *) src0->data + src0_offset;
 
                 // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                 //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
         case GGML_OP_MUL_MAT_ID:
            {
                 cur = 0;
-                const struct ggml_tensor * src0 = node->src[2];
+                const struct ggml_tensor * src0 = node->src[0];
                 const struct ggml_tensor * src1 = node->src[1];
                 const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                 if (src1->type != vec_dot_type) {
                     cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                 }
-                const int n_as = ggml_get_op_params_i32(node, 1);
+                const int n_as = src0->ne[2];
                 cur += GGML_PAD(cur, sizeof(int64_t));       // align
                 cur += n_as * sizeof(int64_t); // matrix_row_counts
                 cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
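
For context on the caller side: after this change all expert matrices live in one 3D tensor indexed along ne[2], and ggml_mul_mat_id is invoked once per expert slot with the whole stack. Below is a minimal sketch of such a call; the helper name and the shapes (moe_expert_slot, n_embd, n_ff, n_expert, n_expert_used, n_tokens) are illustrative assumptions, not part of the patch.

#include "ggml.h"

// Hypothetical helper: apply expert slot i_slot of a stacked MoE projection.
//   as  - all expert weights stacked along dim 2: [n_embd, n_ff, n_expert]
//   ids - chosen expert per token and slot (I32): [n_expert_used, n_tokens]
//   b   - input activations:                      [n_embd, n_tokens]
static struct ggml_tensor * moe_expert_slot(
        struct ggml_context * ctx,
        struct ggml_tensor  * as,
        struct ggml_tensor  * ids,
        struct ggml_tensor  * b,
        int                   i_slot) {
    // one call per expert slot for now; per the NOTE in the patch, a future
    // version is expected to compute all the experts listed in ids at once
    return ggml_mul_mat_id(ctx, as, ids, i_slot, b); // result: [n_ff, n_tokens]
}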