summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c45
1 files changed, 28 insertions, 17 deletions
diff --git a/ggml.c b/ggml.c
index bcf16222..7ff1d0b7 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12311,6 +12311,20 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
+#if GGML_USE_IQK_MULMAT
+ if (ggml_is_contiguous(src1) && dst->type == GGML_TYPE_F32) {
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!iqk_mul_mat(params->type, ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+ src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith, nth)) goto IQK_MulMat_Not_Available1;
+ return;
+ }
+IQK_MulMat_Not_Available1:;
+#endif
+
#if GGML_USE_LLAMAFILE
const bool src1_cont = ggml_is_contiguous(src1);
@@ -12374,19 +12388,18 @@ UseGgmlGemm1:;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
#if GGML_USE_IQK_MULMAT
- if ((vec_dot_type == GGML_TYPE_Q8_K || vec_dot_type == GGML_TYPE_Q8_0 ||
- vec_dot_type == GGML_TYPE_Q8_1) && dst->type == GGML_TYPE_F32) {
+ if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!iqk_mul_mat(ne01, ne11, ne00, src0->type,
- (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
- (const char *)wdata + ggml_row_size(vec_dot_type, ne10)*(i13*ne12 + i12),
- (float *)((char *)dst->data + i12*nb2 + i13*nb3),
- nb1/ggml_type_size(dst->type),
- ith, nth)) goto IQK_MulMat_Not_Available;
+ if (!iqk_mul_mat(params->type, ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith, nth)) goto IQK_MulMat_Not_Available2;
return;
}
-IQK_MulMat_Not_Available:;
+IQK_MulMat_Not_Available2:;
#endif
@@ -12612,14 +12625,12 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t nr1 = cne1; // src1 rows
//
#if GGML_USE_IQK_MULMAT
- if (ne13 == 1 && dst->type == GGML_TYPE_F32 &&
- (vec_dot_type == GGML_TYPE_Q8_K || vec_dot_type == GGML_TYPE_Q8_0 || vec_dot_type == GGML_TYPE_Q8_1)) {
- if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, src0->type,
- (const char *)src0_cur,
- (const char *)wdata,
- (float *)dst->data, nb1, nb2,
- matrix_rows + cur_a*ne12,
- ith, nth)) goto IQK_MulMat_Not_Available;
+ if (ne13 == 1 && dst->type == GGML_TYPE_F32) {
+ if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+ src0->type, (const char *)src0_cur, nb01/ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata, row_size/ggml_type_size(vec_dot_type),
+ (float *)dst->data, nb1, nb2,
+ matrix_rows + cur_a*ne12, ith, nth)) goto IQK_MulMat_Not_Available;
continue;
}
IQK_MulMat_Not_Available:;