summaryrefslogtreecommitdiff
path: root/sgemm.cpp
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-10 16:43:42 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:50 +0300
commit81cf6990f512e82c2c89ba7f89a15c3d98172f84 (patch)
treeb5d6af3449bf6b1a865a103e3ffc88d5e057b3f6 /sgemm.cpp
parentb2acd81c753a098ad8dfb7acf0daf8aebf0ee79a (diff)
iqk_mul_mat: be able to handle any f16/f32 combination on AVX2
But only turning on f16 x f32 and f32 x f16 for now.
Diffstat (limited to 'sgemm.cpp')
-rw-r--r--sgemm.cpp41
1 files changed, 30 insertions, 11 deletions
diff --git a/sgemm.cpp b/sgemm.cpp
index a16752f0..b6c00c4e 100644
--- a/sgemm.cpp
+++ b/sgemm.cpp
@@ -866,22 +866,41 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
if (Ctype != GGML_TYPE_F32)
return false;
- if (task == GGML_TASK_TYPE_COMPUTE && k >= 256 && Atype == GGML_TYPE_F16) {
#if defined __AVX2__ && defined __FMA__
- if (Btype == GGML_TYPE_F32) {
- if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
- return true;
- }
- }
+ //bool is_accepted_float_type = k >= 32 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32;
+ bool is_accepted_float_type = k >= 32 &&
+ ((Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32) || (Atype == GGML_TYPE_F32 && Btype == GGML_TYPE_F16));
#elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
- if (Btype == GGML_TYPE_F16) {
- if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
- return true;
- }
- }
+ bool is_accepted_float_type = k >= 32 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F16;
+#else
+ bool is_accepted_float_type = false;
#endif
+ if (task == GGML_TASK_TYPE_INIT && is_accepted_float_type) {
+ return true;
}
+ if (task == GGML_TASK_TYPE_COMPUTE && is_accepted_float_type) {
+ if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+ return true;
+ }
+ }
+
+// if (task == GGML_TASK_TYPE_COMPUTE && k >= 32 && Atype == GGML_TYPE_F16) {
+//#if defined __AVX2__ && defined __FMA__
+// if (Btype == GGML_TYPE_F32) {
+// if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+// return true;
+// }
+// }
+//#elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
+// if (Btype == GGML_TYPE_F16) {
+// if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+// return true;
+// }
+// }
+//#endif
+// }
+
switch (Atype) {
case GGML_TYPE_F32: {