summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp42
1 files changed, 30 insertions, 12 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index ce868514..a70310d4 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -7492,7 +7492,11 @@ struct FlashQKfp32 {
#ifdef __aarch64__
mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
#else
- mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ if constexpr (D >= 128) {
+ mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ } else {
+ mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ }
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
@@ -7523,7 +7527,7 @@ struct FlashQKfp32 {
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
GGML_ASSERT(nq < 8);
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
switch (nq) {
#ifdef __aarch64__
case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break;
@@ -7545,9 +7549,9 @@ struct FlashQKfp32 {
}
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
- switch (nq) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
+ switch (nq) {
case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break;
@@ -7555,16 +7559,30 @@ struct FlashQKfp32 {
case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break;
+ }
#else
- case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
-#endif
+ if constexpr (D >= 128) {
+ switch (nq) {
+ case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+ }
+ } else {
+ switch (nq) {
+ case 1: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+ }
}
+#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};