summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.clang-tidy24
-rw-r--r--ggml.c45
-rw-r--r--iqk_mul_mat.cpp76
-rw-r--r--iqk_mul_mat.h10
-rw-r--r--sgemm.cpp23
5 files changed, 77 insertions, 101 deletions
diff --git a/.clang-tidy b/.clang-tidy
deleted file mode 100644
index 952c0cca..00000000
--- a/.clang-tidy
+++ /dev/null
@@ -1,24 +0,0 @@
----
-Checks: >
- bugprone-*,
- -bugprone-easily-swappable-parameters,
- -bugprone-implicit-widening-of-multiplication-result,
- -bugprone-misplaced-widening-cast,
- -bugprone-narrowing-conversions,
- readability-*,
- -readability-avoid-unconditional-preprocessor-if,
- -readability-function-cognitive-complexity,
- -readability-identifier-length,
- -readability-implicit-bool-conversion,
- -readability-magic-numbers,
- -readability-uppercase-literal-suffix,
- -readability-simplify-boolean-expr,
- clang-analyzer-*,
- -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,
- performance-*,
- portability-*,
- misc-*,
- -misc-const-correctness,
- -misc-non-private-member-variables-in-classes,
- -misc-no-recursion,
-FormatStyle: none
diff --git a/ggml.c b/ggml.c
index bcf16222..7ff1d0b7 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12311,6 +12311,20 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
+#if GGML_USE_IQK_MULMAT
+ if (ggml_is_contiguous(src1) && dst->type == GGML_TYPE_F32) {
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!iqk_mul_mat(params->type, ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+ src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith, nth)) goto IQK_MulMat_Not_Available1;
+ return;
+ }
+IQK_MulMat_Not_Available1:;
+#endif
+
#if GGML_USE_LLAMAFILE
const bool src1_cont = ggml_is_contiguous(src1);
@@ -12374,19 +12388,18 @@ UseGgmlGemm1:;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
#if GGML_USE_IQK_MULMAT
- if ((vec_dot_type == GGML_TYPE_Q8_K || vec_dot_type == GGML_TYPE_Q8_0 ||
- vec_dot_type == GGML_TYPE_Q8_1) && dst->type == GGML_TYPE_F32) {
+ if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!iqk_mul_mat(ne01, ne11, ne00, src0->type,
- (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
- (const char *)wdata + ggml_row_size(vec_dot_type, ne10)*(i13*ne12 + i12),
- (float *)((char *)dst->data + i12*nb2 + i13*nb3),
- nb1/ggml_type_size(dst->type),
- ith, nth)) goto IQK_MulMat_Not_Available;
+ if (!iqk_mul_mat(params->type, ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith, nth)) goto IQK_MulMat_Not_Available2;
return;
}
-IQK_MulMat_Not_Available:;
+IQK_MulMat_Not_Available2:;
#endif
@@ -12612,14 +12625,12 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t nr1 = cne1; // src1 rows
//
#if GGML_USE_IQK_MULMAT
- if (ne13 == 1 && dst->type == GGML_TYPE_F32 &&
- (vec_dot_type == GGML_TYPE_Q8_K || vec_dot_type == GGML_TYPE_Q8_0 || vec_dot_type == GGML_TYPE_Q8_1)) {
- if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, src0->type,
- (const char *)src0_cur,
- (const char *)wdata,
- (float *)dst->data, nb1, nb2,
- matrix_rows + cur_a*ne12,
- ith, nth)) goto IQK_MulMat_Not_Available;
+ if (ne13 == 1 && dst->type == GGML_TYPE_F32) {
+ if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+ src0->type, (const char *)src0_cur, nb01/ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata, row_size/ggml_type_size(vec_dot_type),
+ (float *)dst->data, nb1, nb2,
+ matrix_rows + cur_a*ne12, ith, nth)) goto IQK_MulMat_Not_Available;
continue;
}
IQK_MulMat_Not_Available:;
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 9934d2e6..5f01a610 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -120,7 +120,7 @@ struct MulMat {
funcs[n_left-1](n, vx, bx, info, nrc_x);
}
}
- static bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny);
+ static bool set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
};
@@ -173,43 +173,50 @@ const uint64_t keven_signs[128] = {
}
-bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B,
+bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth) {
MulMat mm;
- int row_size_q8;
- if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {
+ if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) {
return false;
}
- auto row_size_qx = ggml_row_size((ggml_type)typeA, ne00);
+ if (ggml_task_type(task_type) != GGML_TASK_TYPE_COMPUTE) return ggml_task_type(task_type) == GGML_TASK_TYPE_INIT;
+
+ auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
+ auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
auto nrc_x = (Nx + nth - 1)/nth;
auto first_x = ith*nrc_x;
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
- DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, (size_t)row_size_q8, 0, 1, nullptr, 0};
+ DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
return true;
}
-bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B,
+bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
assert(row_mapping != nullptr);
MulMat mm;
- int row_size_q8;
- if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {
+ if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) {
return false;
}
- int row_size_qx = ggml_row_size((ggml_type)typeA, ne00);
+ auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
+ auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
int nrc_x = (Nx + nth - 1)/nth;
int first_x = ith*nrc_x;
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
- DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)};
+ DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
+ row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
return true;
}
@@ -236,7 +243,6 @@ inline float hsum_float_8(__m256 x) {
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
constexpr static int nrc_y = nrc;
@@ -2394,14 +2400,14 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
}
-bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny) {
+bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
//if (Ny == 1 && (typeA == GGML_TYPE_IQ3_S || typeA == GGML_TYPE_IQ3_XXS)) {
if (Ny == 999 && typeA == GGML_TYPE_IQ3_S) {
return false;
}
- if (typeA == GGML_TYPE_F16) {
+ if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F32) {
for (auto& f : mm.funcs) f = nullptr;
mm.funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>;
mm.funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>;
@@ -2411,10 +2417,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
#ifndef __AVX512F__
mm.funcs[5] = mul_mat_fX_fY_T<6, ggml_half, float>;
#endif
- row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00);
return true;
}
- if (typeA == GGML_TYPE_F32) {
+ if (typeA == GGML_TYPE_F32 && typeB == GGML_TYPE_F16) {
for (auto& f : mm.funcs) f = nullptr;
mm.funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>;
mm.funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>;
@@ -2424,7 +2429,6 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
#ifndef __AVX512F__
mm.funcs[5] = mul_mat_fX_fY_T<6, float, ggml_half>;
#endif
- row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00);
return true;
}
// Using the standard legacy quant template is slightly faster than tiling
@@ -2441,7 +2445,7 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
// return true;
// }
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
+ auto expected_typeB = GGML_TYPE_Q8_K;
switch (typeA) {
case GGML_TYPE_Q2_K:
@@ -2491,33 +2495,35 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_Unpacker>(mm);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+ expected_typeB = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q4_1:
assert (ne00 % QK4_1 == 0);
MulMat::set_functions<Q4_1_Unpacker>(mm);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
+ expected_typeB = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q5_0:
assert (ne00 % QK5_0 == 0);
MulMat::set_functions<Q5_0_Unpacker>(mm);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+ expected_typeB = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q5_1:
assert (ne00 % QK5_1 == 0);
MulMat::set_functions<Q5_1_Unpacker>(mm);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
+ expected_typeB = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
MulMat::set_functions<Q8_0_Unpacker>(mm);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+ expected_typeB = GGML_TYPE_Q8_0;
break;
default:
return false;
}
+ if (typeB != expected_typeB) return false;
+
return true;
}
@@ -3882,7 +3888,7 @@ IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0
template <int nrc_y>
void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%QF16Base::k_step == 0);
+ GGML_ASSERT(n%QF16Base::k_step == 0);
constexpr int k_nx = 5;
const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
@@ -3933,10 +3939,10 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
}
-bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /*Ny*/) {
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
+bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
- if (typeA == GGML_TYPE_F16) {
+ if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
+ if (ne00%8) return false;
for (auto& f : m.funcs) f = nullptr;
m.funcs[0] = mul_mat_f16_f16_T<1>;
m.funcs[1] = mul_mat_f16_f16_T<2>;
@@ -3945,10 +3951,11 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
m.funcs[4] = mul_mat_f16_f16_T<5>;
//m.funcs[5] = mul_mat_f16_f16_T<6>;
//m.funcs[6] = mul_mat_f16_f16_T<7>;
- row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00);
return true;
}
+ auto expected_Btype = GGML_TYPE_Q8_K;
+
switch (typeA) {
case GGML_TYPE_Q2_K:
MulMat::set_functions<DequantizerQ2K>(m);
@@ -3985,28 +3992,29 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+ expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q4_1:
MulMat::set_functions<DequantizerQ41>(m);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
+ expected_Btype = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q5_0:
MulMat::set_functions<DequantizerQ50>(m);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+ expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q5_1:
MulMat::set_functions<DequantizerQ51>(m);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
+ expected_Btype = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q8_0:
MulMat::set_functions<DequantizerQ80>(m);
- row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
+ expected_Btype = GGML_TYPE_Q8_0;
break;
default:
return false;
}
- return true;
+
+ return typeB == expected_Btype;
}
}
diff --git a/iqk_mul_mat.h b/iqk_mul_mat.h
index 4706714b..c1db5eee 100644
--- a/iqk_mul_mat.h
+++ b/iqk_mul_mat.h
@@ -5,11 +5,15 @@
extern "C" {
#endif
-bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B,
+bool iqk_mul_mat(int task_type, long Nx, long Ny, long ne00,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
float * C, long stride_C, int ith, int nth);
-bool iqk_mul_mat_moe(long, long, long, int, int, const void *, const void *,
- float *, long, long, const void *, int, int);
+bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
+ float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
#ifdef __cplusplus
diff --git a/sgemm.cpp b/sgemm.cpp
index 409a9a67..c7189c61 100644
--- a/sgemm.cpp
+++ b/sgemm.cpp
@@ -51,9 +51,6 @@
#include "sgemm.h"
#include "ggml-impl.h"
#include "ggml-quants.h"
-#if GGML_USE_IQK_MULMAT
-#include "iqk_mul_mat.h"
-#endif
#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
@@ -868,26 +865,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
if (Ctype != GGML_TYPE_F32)
return false;
-#if GGML_USE_IQK_MULMAT
-#if defined __AVX2__ && defined __FMA__
- bool is_accepted_float_type = k >= 32 &&
- ((Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32) || (Atype == GGML_TYPE_F32 && Btype == GGML_TYPE_F16));
-#elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
- bool is_accepted_float_type = k >= 32 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F16;
-#else
- bool is_accepted_float_type = false;
-#endif
- if (task == GGML_TASK_TYPE_INIT && is_accepted_float_type) {
- return true;
- }
-
- if (task == GGML_TASK_TYPE_COMPUTE && is_accepted_float_type) {
- if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
- return true;
- }
- }
-#endif
-
switch (Atype) {
case GGML_TYPE_F32: {