summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
author: Iwan Kawrakow <iwan.kawrakow@gmail.com> 2024-06-11 10:33:51 +0300
committer: Iwan Kawrakow <iwan.kawrakow@gmail.com> 2024-06-22 12:02:50 +0300
commit: 7501184eb4d5f9cb12a160919f6603d75c6bc529 (patch)
tree: 22ec63ecd544258b6b6c6c2999c8512519f6638e /iqk_mul_mat.cpp
parent: ad53eabf87816705f49501a54c7555c337bb47ce (diff)
iqk_mul_mat: be independent of llamafile_sgemm
Verified that it works on AVX2. Also turned on any combination of f16 and f32 (i.e., added f16 x f16 and f32 x f32).
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- iqk_mul_mat.cpp 72
1 files changed, 31 insertions, 41 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 5f01a610..d9583e8b 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -2400,50 +2400,42 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
}
+template <typename FloatX, typename FloatY>
+void set_mul_mat_f(MulMat& mm) {
+ for (auto& f : mm.funcs) f = nullptr;
+ mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;
+ mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;
+ mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;
+ mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;
+ mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;
+#ifndef __AVX512F__
+ mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;
+#endif
+}
+
bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
- //if (Ny == 1 && (typeA == GGML_TYPE_IQ3_S || typeA == GGML_TYPE_IQ3_XXS)) {
- if (Ny == 999 && typeA == GGML_TYPE_IQ3_S) {
- return false;
- }
+ (void)Ny;
- if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F32) {
- for (auto& f : mm.funcs) f = nullptr;
- mm.funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>;
- mm.funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>;
- mm.funcs[2] = mul_mat_fX_fY_T<3, ggml_half, float>;
- mm.funcs[3] = mul_mat_fX_fY_T<4, ggml_half, float>;
- mm.funcs[4] = mul_mat_fX_fY_T<5, ggml_half, float>;
-#ifndef __AVX512F__
- mm.funcs[5] = mul_mat_fX_fY_T<6, ggml_half, float>;
-#endif
+ if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
+ if (ne00 % QFBase::k_step) return false;
+ }
+ if (typeA == GGML_TYPE_F16) {
+ switch (typeB) {
+ case GGML_TYPE_F16: set_mul_mat_f<ggml_half, ggml_half>(mm); break;
+ case GGML_TYPE_F32: set_mul_mat_f<ggml_half, float>(mm); break;
+ default: return false;
+ }
return true;
}
- if (typeA == GGML_TYPE_F32 && typeB == GGML_TYPE_F16) {
- for (auto& f : mm.funcs) f = nullptr;
- mm.funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>;
- mm.funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>;
- mm.funcs[2] = mul_mat_fX_fY_T<3, float, ggml_half>;
- mm.funcs[3] = mul_mat_fX_fY_T<4, float, ggml_half>;
- mm.funcs[4] = mul_mat_fX_fY_T<5, float, ggml_half>;
-#ifndef __AVX512F__
- mm.funcs[5] = mul_mat_fX_fY_T<6, float, ggml_half>;
-#endif
+ if (typeA == GGML_TYPE_F32) {
+ switch (typeB) {
+ case GGML_TYPE_F16: set_mul_mat_f<float, ggml_half>(mm); break;
+ case GGML_TYPE_F32: set_mul_mat_f<float, float>(mm); break;
+ default: return false;
+ }
return true;
}
- // Using the standard legacy quant template is slightly faster than tiling
- // as implemented in mul_mat_q80_q80_T
-// if (typeA == GGML_TYPE_Q8_0) {
-// for (auto& f : mm.funcs) f = nullptr;
-// mm.funcs[0] = mul_mat_q80_q80_T<1>;
-// mm.funcs[1] = mul_mat_q80_q80_T<2>;
-// mm.funcs[2] = mul_mat_q80_q80_T<3>;
-//#ifdef __AVX512F__
-// mm.funcs[3] = mul_mat_q80_q80_T<4>;
-//#endif
-// row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
-// return true;
-// }
auto expected_typeB = GGML_TYPE_Q8_K;
@@ -2510,7 +2502,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_Q5_1:
assert (ne00 % QK5_1 == 0);
MulMat::set_functions<Q5_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_0;
+ expected_typeB = GGML_TYPE_Q8_1;
break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
@@ -2522,9 +2514,7 @@ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
return false;
}
- if (typeB != expected_typeB) return false;
-
- return true;
+ return ggml_type(typeB) == expected_typeB;
}
} // namespace