summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt6
-rw-r--r--Makefile2
-rw-r--r--ggml.c11
3 files changed, 12 insertions, 7 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 158174c2..2cc0df3f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -88,6 +88,7 @@ endif()
# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_BLAS "llama: use BLAS" OFF)
+option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ON)
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUDA "llama: use CUDA" OFF)
option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
@@ -286,6 +287,7 @@ if (LLAMA_METAL)
${METALKIT_FRAMEWORK}
)
endif()
+
if (LLAMA_BLAS)
if (LLAMA_STATIC)
set(BLA_STATIC ON)
@@ -368,6 +370,10 @@ if (LLAMA_BLAS)
endif()
endif()
+if (LLAMA_LLAMAFILE)
+ add_compile_definitions(GGML_USE_LLAMAFILE)
+endif()
+
if (LLAMA_QKK_64)
add_compile_definitions(GGML_QKK_64)
endif()
diff --git a/Makefile b/Makefile
index 928fb14c..9a711743 100644
--- a/Makefile
+++ b/Makefile
@@ -222,6 +222,8 @@ endif # LLAMA_DISABLE_LOGS
# disable ggml.c's use of sgemm.cpp
ifdef LLAMA_NO_LLAMAFILE
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0
+else
+ MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=1
endif
# warnings
diff --git a/ggml.c b/ggml.c
index 119686be..593c603f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -33,12 +33,8 @@
#include <unistd.h>
#endif
-#ifndef GGML_USE_LLAMAFILE
#ifdef __ARM_FEATURE_MATMUL_INT8
-#define GGML_USE_LLAMAFILE 0
-#else
-#define GGML_USE_LLAMAFILE 1
-#endif
+#undef GGML_USE_LLAMAFILE
#endif
#if defined(_MSC_VER)
@@ -10879,8 +10875,9 @@ UseGgmlGemm1:;
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
- (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
- nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
+ (const char *)wdata + ggml_row_size(vec_dot_type,
+ nb12/ggml_type_size(src1->type)*i12 +
+ nb13/ggml_type_size(src1->type)*i13),
row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),