diff options
Diffstat (limited to 'sgemm.cpp')
-rw-r--r-- | sgemm.cpp | 1033 |
1 files changed, 0 insertions, 1033 deletions
diff --git a/sgemm.cpp b/sgemm.cpp deleted file mode 100644 index c7189c61..00000000 --- a/sgemm.cpp +++ /dev/null @@ -1,1033 +0,0 @@ -// Copyright 2024 Mozilla Foundation -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS -// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN -// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// -// _ _ ___ _ _ ___ -// | |_(_)_ _ _ _| _ ) | /_\ / __| -// | _| | ' \ || | _ \ |__ / _ \\__ \. -// \__|_|_||_\_, |___/____/_/ \_\___/ -// |__/ -// -// BASIC LINEAR ALGEBRA SUBPROGRAMS -// -// -// This file implements multithreaded CPU matrix multiplication for the -// common contiguous use case C = Aᵀ * B. These kernels are designed to -// have excellent performance[1] for matrices that fit in the CPU cache -// without imposing any overhead such as cache filling or malloc calls. -// -// This implementation does not guarantee any upper bound with rounding -// errors, which grow along with k. Our goal's to maximally exploit the -// hardware for performance, and then use whatever resources remain for -// improving numerical accuracy. -// -// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. -// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Wpedantic" -#pragma GCC diagnostic ignored "-Wignored-attributes" -#endif - -#include "sgemm.h" -#include "ggml-impl.h" -#include "ggml-quants.h" - -#ifdef _MSC_VER -#define NOINLINE __declspec(noinline) -#else -#define NOINLINE __attribute__((__noinline__)) -#endif - -#if defined(__ARM_NEON) || defined(__AVX512F__) -#define VECTOR_REGISTERS 32 -#else -#define VECTOR_REGISTERS 16 -#endif - -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -namespace { - -inline float unhalf(ggml_fp16_t d) { - return GGML_FP16_TO_FP32(d); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED ARITHMETIC OPERATIONS - -#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } -inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } -inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } -#endif // __SSE__ - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } -inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } -inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } -#endif // __AVX__ - -#if defined(__AVX512F__) -inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } -inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } -inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } -#endif // __AVX512F__ - -#if defined(__ARM_NEON) -inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } -inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } -inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } -#endif // __ARM_NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) -inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); } -inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } -inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED FUSED MULTIPLY ADD - -/** - * Computes a * b + c. - */ -template <typename T, typename U> -inline U madd(T a, T b, U c) { - return add(mul(a, b), c); -} - -#if defined(__FMA__) -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -template <> -inline __m256 madd(__m256 a, __m256 b, __m256 c) { - return _mm256_fmadd_ps(a, b, c); -} -#endif -#if defined(__AVX512F__) -template <> -inline __m512 madd(__m512 a, __m512 b, __m512 c) { - return _mm512_fmadd_ps(a, b, c); -} -#endif -#endif - -#if defined(__ARM_FEATURE_FMA) -template <> -inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { - return vfmaq_f32(c, b, a); -} -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) -template <> -inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { - return vfmaq_f16(c, b, a); -} -#endif -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED HORIZONTAL SUM - -#if defined(__ARM_NEON) -inline float hsum(float32x4_t x) { - return vaddvq_f32(x); -} -#endif // __ARM_NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) -inline float hsum(float16x8_t x) { - return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), - vcvt_f32_f16(vget_high_f16(x)))); -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline float hsum(__m128 x) { -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) - x = _mm_add_ps(x, _mm_movehl_ps(x, x)); - x = _mm_add_ss(x, _mm_movehdup_ps(x)); -#else - __m128 t; - t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); - x = _mm_add_ps(x, t); - t = _mm_movehl_ps(t, x); - x = _mm_add_ss(x, t); -#endif - return _mm_cvtss_f32(x); -} -#endif - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline float hsum(__m256 x) { - return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), - _mm256_castps256_ps128(x))); -} -#endif // __AVX__ - -#if defined(__AVX512F__) -inline float hsum(__m512 x) { - return _mm512_reduce_add_ps(x); -} -#endif // __AVX512F__ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED MEMORY LOADING - -template <typename T, typename U> T load(const U *); - -#if defined(__ARM_NEON) -template <> inline float32x4_t load(const float *p) { - return vld1q_f32(p); -} -#if !defined(_MSC_VER) -template <> inline float16x8_t load(const ggml_fp16_t *p) { - return vld1q_f16((const float16_t *)p); -} -template <> inline float32x4_t load(const ggml_fp16_t *p) { - return vcvt_f32_f16(vld1_f16((const float16_t *)p)); -} -#endif // _MSC_VER -#endif // __ARM_NEON - -#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -template <> inline __m128 load(const float *p) { - return _mm_loadu_ps(p); -} -#endif // __SSE__ - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -template <> inline __m256 load(const float *p) { - return _mm256_loadu_ps(p); -} -#endif // __AVX__ - -#if defined(__F16C__) -template <> inline __m256 load(const ggml_fp16_t *p) { - return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); -} -#endif // __F16C__ - -#if defined(__AVX512F__) -template <> inline __m512 load(const float *p) { - return _mm512_loadu_ps(p); -} -template <> inline __m512 load(const ggml_fp16_t *p) { - return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); -} -#endif // __AVX512F__ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// FLOATING POINT MATRIX MULTIPLICATION - -template <int KN, typename D, typename V, typename TA, typename TB, typename TC> -class tinyBLAS { - public: - tinyBLAS(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, - int ith, int nth) - : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - } - - void matmul(int64_t m, int64_t n, int task) { - if (task == GGML_TASK_TYPE_COMPUTE) - mnpack(0, m, 0, n); - } - - private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { -#if VECTOR_REGISTERS == 32 - case 0x55: - mc = 5; - nc = 5; - gemm<5, 5>(m0, m, n0, n); - break; - case 0x45: - mc = 4; - nc = 5; - gemm<4, 5>(m0, m, n0, n); - break; - case 0x54: - mc = 5; - nc = 4; - gemm<5, 4>(m0, m, n0, n); - break; - case 0x44: - mc = 4; - nc = 4; - gemm<4, 4>(m0, m, n0, n); - break; - case 0x53: - mc = 5; - nc = 3; - gemm<5, 3>(m0, m, n0, n); - break; - case 0x35: - mc = 3; - nc = 5; - gemm<3, 5>(m0, m, n0, n); - break; - case 0x43: - mc = 4; - nc = 3; - gemm<4, 3>(m0, m, n0, n); - break; -#else - case 0x55: - case 0x54: - case 0x53: - case 0x45: - case 0x44: - case 0x43: - mc = 4; - nc = 3; - gemm<4, 3>(m0, m, n0, n); - break; - case 0x35: -#endif - case 0x34: - mc = 3; - nc = 4; - gemm<3, 4>(m0, m, n0, n); - break; - case 0x52: - mc = 5; - nc = 2; - gemm<5, 2>(m0, m, n0, n); - break; - case 0x33: - mc = 3; - nc = 3; - gemm<3, 3>(m0, m, n0, n); - break; - case 0x25: - mc = 2; - nc = 5; - gemm<2, 5>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; - gemm<4, 2>(m0, m, n0, n); - break; - case 0x24: - mc = 2; - nc = 4; - gemm<2, 4>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm<3, 2>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm<2, 3>(m0, m, n0, n); - break; - case 0x51: - mc = 5; - nc = 1; - gemm<5, 1>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; - gemm<4, 1>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm<2, 2>(m0, m, n0, n); - break; - case 0x15: - mc = 1; - nc = 5; - gemm<1, 5>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; - gemm<1, 4>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm<3, 1>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm<1, 3>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm<2, 1>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm<1, 1>(m0, m, n0, n); - break; - default: - return; - } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; - mnpack(mp, m, n0, np); - mnpack(m0, m, np, n); - } - - template <int RM, int RN> - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; - D Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; l += KN) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l), - load<V>(B + ldb * (jj + j) + l), - Cv[j][i]); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } - - const TA *const A; - const TB *const B; - TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; - -////////////////////////////////////////////////////////////////////////////////////////// -// QUANT ZERO MATRIX MULTIPLICATION - -#if defined(__ARM_FEATURE_DOTPROD) -template <typename TA> -class tinyBLAS_Q0_ARM { - public: - tinyBLAS_Q0_ARM(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth) - : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - } - - void matmul(int64_t m, int64_t n, int task) { - if (task == GGML_TASK_TYPE_COMPUTE) - mnpack(0, m, 0, n); - } - - private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { - case 0x33: - mc = 3; - nc = 3; - gemm<3, 3>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm<3, 2>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm<2, 3>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm<2, 2>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm<3, 1>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm<1, 3>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm<2, 1>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm<1, 1>(m0, m, n0, n); - break; - default: - return; - } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; - mnpack(mp, m, n0, np); - mnpack(m0, m, np, n); - } - - template <int RM, int RN> - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; - float32x4_t Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; ++l) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = vmlaq_n_f32(Cv[j][i], - vcvtq_f32_s32(vdotq_s32( - vdotq_s32(vdupq_n_s32(0), - load_lo(A + lda * (ii + i) + l), - load_lo(B + ldb * (jj + j) + l)), - load_hi(A + lda * (ii + i) + l), - load_hi(B + ldb * (jj + j) + l))), - unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } - - inline int8x16_t load_lo(const block_q8_0 *b) { - return vld1q_s8(b->qs); - } - - inline int8x16_t load_hi(const block_q8_0 *b) { - return vld1q_s8(b->qs + 16); - } - - inline int8x16_t load_lo(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), - vdupq_n_u8(0x0f))), - vdupq_n_s8(0x8)); - } - - inline int8x16_t load_hi(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), - vdupq_n_s8(0x8)); - } - - const TA *const A; - const block_q8_0 *const B; - float *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; -#endif // __ARM_FEATURE_DOTPROD - -#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) -template <typename TA, typename TB, typename TC> -class tinyBLAS_Q0_AVX { - public: - tinyBLAS_Q0_AVX(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, - int ith, int nth) - : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - } - - void matmul(int64_t m, int64_t n, int task) { - if (task == GGML_TASK_TYPE_COMPUTE) - mnpack(0, m, 0, n); - } - - private: - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { -#if VECTOR_REGISTERS == 32 - case 0x44: - mc = 4; - nc = 4; - gemm<4, 4>(m0, m, n0, n); - break; - case 0x43: - mc = 4; - nc = 3; - gemm<4, 3>(m0, m, n0, n); - break; - case 0x34: - mc = 3; - nc = 4; - gemm<3, 4>(m0, m, n0, n); - break; - case 0x33: - mc = 3; - nc = 3; - gemm<3, 3>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; - gemm<4, 2>(m0, m, n0, n); - break; - case 0x24: - mc = 2; - nc = 4; - gemm<2, 4>(m0, m, n0, n); - break; -#else - case 0x44: - case 0x43: - case 0x42: - mc = 4; - nc = 2; - gemm<4, 2>(m0, m, n0, n); - break; - case 0x34: - case 0x24: - mc = 2; - nc = 4; - gemm<2, 4>(m0, m, n0, n); - break; - case 0x33: -#endif - case 0x32: - mc = 3; - nc = 2; - gemm<3, 2>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm<2, 3>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; - gemm<4, 1>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm<2, 2>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; - gemm<1, 4>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm<3, 1>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm<1, 3>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm<2, 1>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm<1, 1>(m0, m, n0, n); - break; - default: - return; - } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; - mnpack(mp, m, n0, np); - mnpack(m0, m, np, n); - } - - template <int RM, int RN> - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; - __m256 Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; ++l) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) { -#if defined(__AVX2__) - __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))); -#else - __m128i ali0 = load0(A + lda * (ii + i) + l); - __m128i ali1 = load1(A + lda * (ii + i) + l); - __m128i blj0 = load0(B + ldb * (jj + j) + l); - __m128i blj1 = load1(B + ldb * (jj + j) + l); - - __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); - __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); - __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); - __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); - - // updot - const __m128i oneFill = _mm_set1_epi16(1); - __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); - __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); - __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); -#endif - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), - udTmp, - Cv[j][i]); - } - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } - - inline __m256i load(const block_q8_0 *b) { - return _mm256_loadu_si256((const __m256i *)b->qs); - } - - inline __m128i load0(const block_q8_0 *b) { - return _mm_loadu_si128((const __m128i *)b->qs); - } - - inline __m128i load1(const block_q8_0 *b) { - return _mm_loadu_si128(((const __m128i *)b->qs) + 1); - } - - inline __m256i load(const block_q4_0 *b) { - return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); - } - - inline __m128i load0(const block_q4_0 *b) { - const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); - return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); - } - - inline __m128i load1(const block_q4_0 *b) { - const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); - return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); - } - - inline __m256 updot(__m256i u, __m256i s) { - __m256i res; -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) - res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); -#else - res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); -#endif - return _mm256_cvtepi32_ps(res); - } - - static inline __m256i denibble(const uint8_t *p) { - __m128i x = _mm_loadu_si128((const __m128i *)p); - return _mm256_and_si256(_mm256_set1_epi8(15), - _mm256_insertf128_si256(_mm256_castsi128_si256(x), - _mm_srli_epi16(x, 4), 1)); - } - - const TA *const A; - const TB *const B; - TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; -#endif // __AVX__ - -} // namespace - -/** - * Performs optimized matrix multiplication on CPU. - * - * This subroutine may compute C = Aᵀ * B with column major ordering. - * Despite its name, this isn't a generalized implementation. Work is - * only performed when a handwritten kernel is written and available. - * Otherwise the caller should fall back to a general matmul routine. - * - * For example, for single-threaded single-precision GEMM you can say - * - * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, - * 0, 1, GGML_TASK_TYPE_COMPUTE, - * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32); - * - * @param m is rows in `A` and `C` - * @param n is cols in `B` and `C` - * @param k is cols in `A` and rows in `B` - * @param A is first input matrix (always transposed) - * @param lda is row stride of `A` - * @param B is second input matrix (never transposed) - * @param ldb is row stride of `B` - * @param C is input/output array of output matrices - * @param ldc is row stride of `C` - * @param ith is thread id (must be less than `nth`) - * @param nth is number of threads (must be greater than zero) - * @param task is GGML task type - * @param Atype is GGML data type of `A` - * @param Btype is GGML data type of `B` - * @param Ctype is GGML data type of `C` - * @return true if this function was able to service the matmul request - */ - -bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, - int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { - - assert(m >= 0); - assert(n >= 0); - assert(k >= 0); - assert(lda >= k); - assert(ldb >= k); - assert(ldc >= m); - assert(nth > 0); - assert(ith < nth); - - if (Ctype != GGML_TYPE_F32) - return false; - - switch (Atype) { - - case GGML_TYPE_F32: { - if (Btype != GGML_TYPE_F32) - return false; -#if defined(__AVX512F__) - if (k % 16) - return false; - tinyBLAS<16, __m512, __m512, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#elif defined(__AVX__) || defined(__AVX2__) - if (k % 8) - return false; - tinyBLAS<8, __m256, __m256, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#elif defined(__ARM_NEON) - if (n < 4) - return false; - if (k % 4) - return false; - tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#else - return false; -#endif - } - - case GGML_TYPE_F16: { -#if defined(__AVX512F__) - if (k % 16) - return false; - if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) - if (k % 8) - return false; - if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) - if (n < 8) - return false; - if (k % 8) - return false; - if (Btype != GGML_TYPE_F16) - return false; - tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const ggml_fp16_t *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#elif defined(__ARM_NEON) && !defined(_MSC_VER) - if (k % 4) - return false; - if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#else - return false; -#endif - } - - case GGML_TYPE_Q8_0: { - if (Btype != GGML_TYPE_Q8_0) - return false; -#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM<block_q8_0> tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#else - return false; -#endif - } - - case GGML_TYPE_Q4_0: { - if (Btype != GGML_TYPE_Q8_0) - return false; -#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM<block_q4_0> tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n, task); - return true; -#else - return false; -#endif - } - - default: - return false; - } - - (void)m; - (void)n; - (void)k; - (void)A; - (void)lda; - (void)B; - (void)ldb; - (void)C; - (void)ldc; - (void)ith; - (void)nth; - (void)task; - (void)Atype; - (void)Btype; - (void)Ctype; -} |