diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 54 |
1 file changed, 54 insertions, 0 deletions
@@ -4,6 +4,7 @@
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 #include "ggml.h"
+#include "sgemm.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -32,6 +33,14 @@
 #include <unistd.h>
 #endif
 
+#ifndef GGML_USE_LLAMAFILE
+#ifdef __ARM_FEATURE_MATMUL_INT8
+#define GGML_USE_LLAMAFILE 0
+#else
+#define GGML_USE_LLAMAFILE 1
+#endif
+#endif
+
 #if defined(_MSC_VER)
 // disable "possible loss of data" to avoid hundreds of casts
 // we should just be careful :)
@@ -10810,6 +10819,28 @@ static void ggml_compute_forward_mul_mat(
     }
 #endif
 
+#if GGML_USE_LLAMAFILE
+    if (nb10 == ggml_type_size(src1->type)) {
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(src0->type),
+                                     (const char *)src1->data + i12*nb12 + i13*nb13,
+                                     nb11/ggml_type_size(src1->type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     ith, nth,
+                                     params->type,
+                                     src0->type,
+                                     src1->type,
+                                     dst->type))
+                    goto UseGgmlGemm1;
+        return;
+    }
+UseGgmlGemm1:;
+#endif
+
     if (params->type == GGML_TASK_TYPE_INIT) {
         if (ith != 0) {
             return;
@@ -10841,6 +10872,29 @@ static void ggml_compute_forward_mul_mat(
     const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
+#if GGML_USE_LLAMAFILE
+    if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(src0->type),
+                                     (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
+                                                            nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
+                                     row_size/ggml_type_size(vec_dot_type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     ith, nth,
+                                     params->type,
+                                     src0->type,
+                                     vec_dot_type,
+                                     dst->type))
+                    goto UseGgmlGemm2;
+        return;
+    }
+UseGgmlGemm2:;
+#endif
+
     const int64_t nr0 = ne01; // src0 rows
     const int64_t nr1 = ne1*ne12*ne13; // src1 rows
 