From 81cf6990f512e82c2c89ba7f89a15c3d98172f84 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 10 Jun 2024 16:43:42 +0300
Subject: iqk_mul_mat: be able to handle any f16/f32 combination on AVX2

But only turning on f16 x f32 and f32 x f16 for now.
---
 sgemm.cpp | 41 ++++++++++++++++++++++++++++++-----------
 1 file changed, 30 insertions(+), 11 deletions(-)

(limited to 'sgemm.cpp')

diff --git a/sgemm.cpp b/sgemm.cpp
index a16752f0..b6c00c4e 100644
--- a/sgemm.cpp
+++ b/sgemm.cpp
@@ -866,22 +866,41 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     if (Ctype != GGML_TYPE_F32)
         return false;
 
-    if (task == GGML_TASK_TYPE_COMPUTE && k >= 256 && Atype == GGML_TYPE_F16) {
 #if defined __AVX2__ && defined __FMA__
-        if (Btype == GGML_TYPE_F32) {
-            if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
-                return true;
-            }
-        }
+    //bool is_accepted_float_type = k >= 32 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32;
+    bool is_accepted_float_type = k >= 32 &&
+        ((Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32) || (Atype == GGML_TYPE_F32 && Btype == GGML_TYPE_F16));
 #elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
-        if (Btype == GGML_TYPE_F16) {
-            if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
-                return true;
-            }
-        }
+    bool is_accepted_float_type = k >= 32 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F16;
+#else
+    bool is_accepted_float_type = false;
 #endif
+    if (task == GGML_TASK_TYPE_INIT && is_accepted_float_type) {
+        return true;
     }
 
+    if (task == GGML_TASK_TYPE_COMPUTE && is_accepted_float_type) {
+        if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+            return true;
+        }
+    }
+
+//    if (task == GGML_TASK_TYPE_COMPUTE && k >= 32 && Atype == GGML_TYPE_F16) {
+//#if defined __AVX2__ && defined __FMA__
+//        if (Btype == GGML_TYPE_F32) {
+//            if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+//                return true;
+//            }
+//        }
+//#elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
+//        if (Btype == GGML_TYPE_F16) {
+//            if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+//                return true;
+//            }
+//        }
+//#endif
+//    }
+
     switch (Atype) {
 
     case GGML_TYPE_F32: {
-- 
cgit v1.2.3