// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp fenc=utf-8 :vi
//
//
// Copyright (C) 2024 Iwan Kawrakow
// MIT license
// SPDX-License-Identifier: MIT
//

//
// Shared building blocks for the iqk matrix-multiplication kernels:
//   * DataInfo, which addresses rows of the right-hand operand ("src1") and
//     of the destination, optionally through a row-index mapping;
//   * the mul_mat_t kernel signature and macros that fill kernel tables;
//   * scale-unpacking helpers and a sign lookup table;
//   * per-ISA helpers: AVX2 (plus AVX-512/VNNI under HAVE_FANCY_SIMD) when
//     __AVX2__ is defined, NEON otherwise.
// Everything after iqk_config.h is compiled only when IQK_IMPLEMENT is defined
// (presumably decided by iqk_config.h -- confirm there).
//
// NOTE(review): in this copy, text between angle brackets appears to have been
// stripped. The #include lines below have no target. Several `template`
// keywords have no parameter list (e.g. `template struct Q8`, whose body uses
// the undeclared names `nrc` and `block_q8`). std::array, std::lock_guard,
// std::chrono::time_point/duration_cast, `Q8 q8(info)` and the `kernel`
// entries of IQK_SET_MUL_MAT_FUNCTIONS_T carry no template arguments. Those
// tokens are left exactly as found; restore them from the upstream file
// before building.
//

#pragma once

#include "iqk_config.h"

#if defined IQK_IMPLEMENT

#include
#include
#include
#include

#include "ggml-impl.h"
#include "ggml-quants.h"
#include "iqk_mul_mat.h"
#include "iqk_quantize.h"

#define GGML_COMMON_IMPL_C
#include "ggml-common.h"

// Set to non-zero to compile the Perf profiler below.
#define FA_TIMING 0

#include
#include
#if FA_TIMING
#include
#include
// Optional wall-clock profiler, compiled only when FA_TIMING is non-zero.
// times[k] accumulates the durations recorded for slot k. instance() returns
// a function-local static object with report = true; its destructor prints
// the total and each non-zero slot together with its share of the total.
struct Perf {
    using TimePoint = std::chrono::time_point;
    std::array times = {};
    std::mutex mutex;
    bool report;
    static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
    // Add the time elapsed since t1 to slot `what`, holding the mutex.
    inline void accum(int what, const TimePoint& t1) {
        auto t2 = cur_time();
        auto dt = delta(t1, t2);
        std::lock_guard lock(mutex);
        times[what] += dt;
    }
    // Same as accum() but without locking; the caller must ensure exclusive access.
    inline void accum_nolock(int what, const TimePoint& t1) {
        auto t2 = cur_time();
        auto dt = delta(t1, t2);
        times[what] += dt;
    }
    // Add other's per-slot totals into this object.
    // NOTE(review): only this->mutex is taken; `other` is read unsynchronized.
    inline void add(const Perf& other) {
        std::lock_guard lock(mutex);
        for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
    }
    Perf(bool r) : report(r) {}
    // Print the report (if enabled and anything was recorded).
    ~Perf() {
        if (report) {
            double tot = 0;
            for (auto& t : times) tot += t;
            if (!tot) return;
            printf("======================= Timing: %g ms in total\n", tot);
            for (int i = 0; i < int(times.size()); ++i) {
                if (times[i]) {
                    printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
                }
            }
        }
    }
    static Perf& instance() {
        static Perf p(true);
        return p;
    }
    // Elapsed time from t1 to t2: the duration count scaled by 1e-6. The report
    // labels these values "ms", so the (stripped) duration_cast target is
    // presumably nanoseconds -- confirm.
    static double delta(const TimePoint& t1, const TimePoint& t2) {
        return 1e-6*std::chrono::duration_cast(t2-t1).count();
    }
};
#endif

#ifdef __AVX2__
// Build a __m256i with `b` in the low 128 bits and `a` in the high 128 bits
// (same argument order as _mm256_set_m128i(hi, lo)).
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#endif

// Row-index pair consulted by DataInfo when row_mapping is set.
// src1_row() uses (i1 % ne11, i2); dst_row() uses (i1, i2).
typedef struct {
    int32_t i1;
    int32_t i2;
} mmid_row_mapping;

// Addressing helper passed to every mul_mat_t kernel. It locates src1 rows
// (read through cy) and destination rows (written through s). Row indices iy
// are relative to cur_y and may be remapped through row_mapping.
struct DataInfo {
    float       * s;      // destination base; dst_row() offsets it in floats
    const char  * cy;     // src1 base; src1_row() offsets it in bytes
    size_t        bs;     // destination row stride, in floats
    size_t        by;     // src1 row stride, in bytes
    int           cur_y = 0;  // added to every iy before a row is addressed
    int           ne11;       // with row_mapping only: i1 is reduced modulo ne11 and i2 is scaled by ne11 for src1
    const mmid_row_mapping * row_mapping = nullptr;  // optional row indirection; nullptr => rows are at regular strides
    size_t        bs2 = 0;    // with row_mapping only: destination stride (in floats) applied to i2

    // Pointer to src1 row iy.
    // Without a mapping: cy + (cur_y + iy)*by.
    // With a mapping: row (i1 % ne11) + i2*ne11, times by.
    inline const char * src1_row(int iy) const {
        if (!row_mapping) return cy + (cur_y + iy)*by;
        int i11 = row_mapping[cur_y + iy].i1 % ne11;
        int i12 = row_mapping[cur_y + iy].i2;
        return cy + (i11 + i12*ne11)*by;
    }

    // Write a scalar, or a vector of consecutive floats, starting at column ix
    // of destination row iy (vector overloads use unaligned stores).
    inline void store(int ix, int iy, float result) const { *(dst_row(iy) + ix) = result; }
#ifdef __AVX__
    inline void store(int ix, int iy, __m128 result) const { _mm_storeu_ps(dst_row(iy) + ix, result); }
    inline void store(int ix, int iy, __m256 result) const { _mm256_storeu_ps(dst_row(iy) + ix, result); }
#endif
#ifdef __AVX512F__
    inline void store(int ix, int iy, __m512 result) const { _mm512_storeu_ps(dst_row(iy) + ix, result); }
#endif
#ifdef __ARM_NEON
    inline void store(int ix, int iy, float32x4_t result) const { vst1q_f32(dst_row(iy) + ix, result); }
#endif

    // Pointer to destination row iy.
    // Without a mapping: s + (cur_y + iy)*bs. With a mapping: s + i1*bs + i2*bs2.
    inline float * dst_row(int iy) const {
        if (!row_mapping) return s + (cur_y + iy)*bs;
        int i12 = row_mapping[cur_y + iy].i2;
        int i1  = row_mapping[cur_y + iy].i1;
        int i2  = i12;
        return s + i1*bs + i2*bs2;
    }
};

// Signature shared by the matrix-multiplication kernels. As used by the
// NEON mul_mat_qX_K_q8_K_T below:
//   n      row length in quantized elements (a multiple of QK_K there);
//   vx, bx left-operand base and its row stride in bytes (see BaseDequantizer::new_row);
//   info   src1/destination addressing;
//   nrc_x  number of left-operand rows to process.
typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);

// Number of kernel-table slots filled by the macros below (funcs[0..7]);
// presumably the maximum number of src1 rows handled per kernel call.
#define IQK_MAX_NY 8

// Fill funcs[0..7] with `kernel` instantiated for 1..8 src1 rows
// (IQK_SET_MUL_MAT_FUNCTIONS uses kernel<1> .. kernel<8>). In the _T variant
// the template argument lists (which also involve Dequantizer) are missing
// from this copy -- see the NOTE at the top of the file.
// Both definitions end with a line-continuation backslash, so the line after
// each macro must stay blank.
#define IQK_SET_MUL_MAT_FUNCTIONS_T(kernel, Dequantizer, funcs) \
    funcs[0] = kernel;\
    funcs[1] = kernel;\
    funcs[2] = kernel;\
    funcs[3] = kernel;\
    funcs[4] = kernel;\
    funcs[5] = kernel;\
    funcs[6] = kernel;\
    funcs[7] = kernel;\

#define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
    funcs[0] = kernel<1>;\
    funcs[1] = kernel<2>;\
    funcs[2] = kernel<3>;\
    funcs[3] = kernel<4>;\
    funcs[4] = kernel<5>;\
    funcs[5] = kernel<6>;\
    funcs[6] = kernel<7>;\
    funcs[7] = kernel<8>;\

// ==================================================================================================

// Unpack sixteen 6-bit values packed into the 12 bytes at scales8 into the
// 16 bytes at aux32 (written as four uint32 words):
//   out bytes  0..3  = low 6 bits of in bytes 0..3
//   out bytes  4..7  = low nibble of in bytes 8..11  | (top 2 bits of in bytes 0..3) << 4
//   out bytes  8..11 = low 6 bits of in bytes 4..7
//   out bytes 12..15 = high nibble of in bytes 8..11 | (top 2 bits of in bytes 4..7) << 4
// All inputs are read before any output is written.
// The byte numbering assumes little-endian uint16/uint32 reinterpretation.
// This looks like the packed scale/min layout of the K-quants -- confirm at the call sites.
// NOTE(review): scales[k] (uint16_t) is promoted to int before `<< 16`, so a
// value >= 0x8000 overflows int, which is undefined behavior before C++20.
// Casting to uint32_t before shifting would make it well-defined.
static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
    const uint16_t * scales = (const uint16_t *)scales8;
    const uint32_t a0 = scales[0] | (scales[1] << 16);
    const uint32_t a1 = scales[2] | (scales[3] << 16);
    const uint32_t a2 = scales[4] | (scales[5] << 16);
    aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
    aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
    aux32[2] = a1 & 0x3f3f3f3f;
    aux32[0] = a0 & 0x3f3f3f3f;
}

// Sign patterns for 8 int8 lanes, indexed by a 7-bit value k.
// For bytes counted from least significant:
//   byte j (j = 0..6) is 0xff (-1 as int8) when bit j of k is set, else 0x01 (+1);
//   byte 7 is 0xff exactly when k has an odd number of set bits.
// Every entry therefore contains an even number of -1 lanes.
// A namespace-scope `const` has internal linkage in C++, so each translation
// unit that includes this header gets its own copy of the table.
const uint64_t keven_signs[128] = {
    0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
    0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff, 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
    0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff, 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
    0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff, 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
    0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff, 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
    0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff, 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
    0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff, 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
    0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff, 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
    0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff, 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
    0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff, 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
    0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff, 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
    0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff, 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
    0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff, 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
    0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff, 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
    0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff, 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
    0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff, 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
};

#ifdef __AVX2__

// NOTE(review): token-for-token identical redefinition of the macro defined
// near the top of the file (allowed by the preprocessor, but redundant).
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

// Sum of the 4 float lanes of x.
static inline float hsum_float_4(__m128 x) {
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
    return _mm_cvtss_f32(x);
}
// Sum of the 8 float lanes of x (adds the two 128-bit halves first).
static inline float hsum_float_8(__m256 x) {
    return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
}
// Sum of the 8 int32 lanes of a (wrap-around integer addition).
static inline int hsum_i32_8(const __m256i a) {
    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
// Maximum of the 8 float lanes of x.
// NOTE(review): hmax_f32_8 and hmax_float_8 are identical apart from whitespace.
static inline float hmax_f32_8(__m256 x) {
    __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
    return _mm_cvtss_f32(max4);
}
static inline float hmax_float_8(__m256 x) {
    __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
    max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
    max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
    return _mm_cvtss_f32(max4);
}
// Reduce four accumulators at once: lane k of the result is the sum of the
// 4 lanes of accm[k]. Overwrites accm[0] and accm[1].
static inline __m128 hsum_float_4x4(__m128 * accm) {
    accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2]));
    accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3]));
    return _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[1]), _mm_unpackhi_ps(accm[0], accm[1]));
}
// Reduce eight accumulators at once: lane k of the result is the sum of the
// 8 lanes of accm[k]. Overwrites accm[0..3].
static inline __m256 hsum_float_8x8(__m256 * accm) {
    for (int i = 0; i < 4; ++i) {
        // Low 128 bits: accm[i] halves added; high 128 bits: accm[i+4] halves added.
        accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i + 4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i + 4], 0x31));
        // Equivalent formulation kept for reference:
        //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
        //                          _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
    }
    // Same transpose-and-add as hsum_float_4x4, applied within each 128-bit half.
    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i + 2]), _mm256_unpackhi_ps(accm[i], accm[i + 2]));
    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
// Per-128-bit-half reduction of four accumulators:
//   result lane k     (k = 0..3) = sum of the low 4 lanes of accm[k];
//   result lane 4 + k            = sum of the high 4 lanes of accm[k].
// Overwrites accm[0] and accm[1].
static inline __m256 hsum_float_4x8(__m256 * accm) {
    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i + 2]), _mm256_unpackhi_ps(accm[i], accm[i + 2]));
    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
// 16-entry unsigned byte lookup table. Presumably the IQ4_NL non-linear grid
// offset by +128 so it fits in uint8 -- confirm against ggml's kvalues_iq4nl.
static inline __m128i load_iq4nl_values_128() {
    static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
    return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
}
// The same 16-byte table replicated into both 128-bit halves (for per-lane shuffles).
static inline __m256i load_iq4nl_values_256() {
    auto val128 = load_iq4nl_values_128();
    return MM256_SET_M128I(val128, val128);
}
#ifdef HAVE_FANCY_SIMD
// The same 16-byte table replicated into all four 128-bit lanes.
static inline __m512i load_iq4nl_values_512() {
    auto val256 = load_iq4nl_values_256();
    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
}
#endif
// First 16 bytes of iq4k_values (declared elsewhere).
static inline __m128i load_iq4k_values_128() {
    return _mm_loadu_si128((const __m128i *)iq4k_values);
}
// Those 16 bytes replicated into both 128-bit halves.
static inline __m256i load_iq4k_values_256() {
    auto val128 = load_iq4k_values_128();
    return MM256_SET_M128I(val128, val128);
}

// View of nrc_y src1 rows, each an array of block_q8 blocks. The block type
// must provide `qs` (quants), `bsums` and `d` (scale).
template struct Q8 {

    constexpr static int nrc_y = nrc;

    Q8(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
    }

#ifdef HAVE_FANCY_SIMD
    // j-th 64-byte chunk of the quants of block i in row iy.
    inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }
#endif
    // j-th 32-byte chunk of the quants of block i in row iy.
    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
    // First 32 bytes of bsums of block i in row iy.
    inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }
    inline float scale(int iy, int i) const { return y[iy][i].d; }

    const block_q8 * y[nrc_y];
};

// View of nrc_y src1 rows laid out as 5 floats followed by int8 quants.
// The 5 floats are copied into d: the first four are returned by
// scale(iy, k) / scale(iy), and the fifth by sum_row(iy).
template struct Q8_16 {

    constexpr static int nrc_y = nrc;

    Q8_16(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto ptr = (const float *)info.src1_row(iy);
            std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
            y[iy] = (const int8_t *)(ptr + 5);
        }
    }

#ifdef HAVE_FANCY_SIMD
    // i-th 64-byte chunk of the quants of row iy.
    inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
#endif
    // i-th 32-byte chunk of the quants of row iy.
    inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
    inline float scale(int iy, int k) const { return d[5*iy+k]; }
    inline float sum_row(int iy) const { return d[5*iy + 4]; }
    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }

    float d[5*nrc_y];
    const int8_t * y[nrc_y];
};

// Helpers for blocks carrying 8 int16 "mins".
struct Scales8KBase {
    // Expand the 8 int16 values in mins128 to 16 by duplicating each one
    // ([m0,m0,m1,m1,...,m7,m7]). Multiply-add them against q8 bsums of block i,
    // so int32 lane k = m_k*(bsum[2k] + bsum[2k+1]). Then, for each row iy:
    //   accd[iy] += c * q8.scale(iy, i) * float(that product).
    template
    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            const __m256i q8s = q8.load_bsums(iy, i);
            const __m256i prod = _mm256_madd_epi16(mins, q8s);
            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
        }
    }
    // The same duplication as accum_mins, returned without accumulating.
    inline __m256i shuffle(__m128i mins) const {
        return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
    }
    // Byte shuffles duplicating int16 elements 0..3 (shuffles[0]) and 4..7 (shuffles[1]).
    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
};

// Common row-pointer state for x86 dequantizers. new_row(ix) points x at the
// blocks of row ix (row stride bx bytes). When per_row_scale is set, the row
// starts with a scale (ggml_half if is_f16, else float) that is loaded into d,
// and the blocks follow it. Otherwise d is not touched here.
template struct BaseDequantizer {
    BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
    inline void new_row(int ix) {
        if constexpr (per_row_scale) {
            if constexpr (is_f16) {
                const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
                d = GGML_FP16_TO_FP32(*dptr);
                x = (const Block *)(dptr + 1);
            } else {
                const float * dptr = (const float *)((const char *)vx + bx*ix);
                d = *dptr;
                x = (const Block *)(dptr + 1);
            }
        } else {
            x = (const Block *)((const char *)vx + bx*ix);
        }
    }

    const void *  vx;
    const size_t  bx;
    const Block * x;
    float d;
};

// Scaled integer dot products of 4 vectors of unsigned quants (bits.values)
// with the q8 quants of block i, for every row iy.
//   j == 0: uses q8 chunks 0..3 and initializes sumi[iy].
//   j != 0: uses chunks 4..7 and accumulates into sumi[iy].
// Each int16 pair-sum (saturating _mm256_maddubs_epi16) is multiplied by the
// matching int16 scales[k] and pair-summed into int32.
template
static inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    if (j == 0) {
#ifdef HAVE_FANCY_SIMD
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
        }
#else
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));
        }
#endif
    } else {
#ifdef HAVE_FANCY_SIMD
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
        }
#else
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
        }
#endif
    }
}

// Same contract as multiply_add, but for *signed* quants in bits.values.
// _mm256_maddubs_epi16 needs an unsigned first operand, so it is given
// |x| = sign(x, x) while the q8 operand takes the sign of x.
// (x == 0 zeroes both, which still yields the correct product 0.)
template
static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    __m256i p[4];
    if (j == 0) {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            for (int k = 0; k < 4; ++k) {
                auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
                p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k])));
            }
            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
        }
    } else {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            for (int k = 0; k < 4; ++k) {
                auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]);
                p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k])));
            }
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2]));
            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3]));
        }
    }
}

#ifdef HAVE_FANCY_SIMD
// HAVE_FANCY_SIMD is set elsewhere; the code under it uses AVX-512 (and, in
// multiply_add above, VNNI) intrinsics.

// Qword index vectors for _mm512_permutex2var_epi64(a, idx, b):
//   permute1 -> [low 256 bits of a | low 256 bits of b]
//   permute2 -> [high 256 bits of a | high 256 bits of b]
struct BlockPermuter {
    const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
};

// Unpacks 4-bit quants into 4 x 64 bytes (values[0..3], one nibble per byte).
// The 16-bit shifts pull bits in from the neighbouring byte; the `& ml` mask drops them.
struct Q4Bits {
    // 128 bytes. Each 64-byte half h (values[2h], values[2h+1]) holds:
    //   values[2h]   = [low nibbles  of bytes 0..31 | high nibbles of bytes 0..31]  of that half;
    //   values[2h+1] = [low nibbles  of bytes 32..63 | high nibbles of bytes 32..63] of that half.
    inline void prepare(const uint8_t * q4) {
        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
        auto tmp1 = _mm512_and_si512(q4bits, ml);
        auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
        values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
        tmp1 = _mm512_and_si512(q4bits, ml);
        tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
        values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
    }
    // 128 bytes:
    //   values[0] = low nibbles of bytes 0..63,   values[1] = high nibbles of bytes 0..63;
    //   values[2] = low nibbles of bytes 64..127, values[3] = high nibbles of bytes 64..127.
    inline void prepare64(const uint8_t * q4) {
        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
        values[0] = _mm512_and_si512(q4bits, ml);
        values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
        values[2] = _mm512_and_si512(q4bits, ml);
        values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
    }
    // 128 bytes. For each 32-byte chunk k:
    //   values[k] = [low nibbles of chunk k | high nibbles of chunk k].
    inline void prepare64a(const uint8_t * q4) {
        for (int k = 0; k < 4; ++k) {
            auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
            values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
            values[k] = _mm512_and_si512(values[k], ml);
        }
    }
    __m512i values[4];
    const __m512i ml = _mm512_set1_epi8(0xf);
    const BlockPermuter perm;
};

// Unpacks 64 bytes of 2-bit quants into 4 x 64 bytes (values 0..3):
//   values[0] = [bits 0-1 of bytes 0..31  | bits 2-3 of bytes 0..31]
//   values[1] = [bits 4-5 of bytes 0..31  | bits 6-7 of bytes 0..31]
//   values[2] = [bits 0-1 of bytes 32..63 | bits 2-3 of bytes 32..63]
//   values[3] = [bits 4-5 of bytes 32..63 | bits 6-7 of bytes 32..63]
struct Q2Bits {
    inline void prepare(const uint8_t * q2) {
        auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
        auto tmp = _mm512_srli_epi16(q2bits, 2);
        values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
        values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
        values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
        values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
        values[0] = _mm512_and_si512(values[0], ml);
        values[2] = _mm512_and_si512(values[2], ml);
    }
    __m512i values[4];
    const __m512i ml = _mm512_set1_epi8(0x03);
    BlockPermuter perm;
};

#else

// AVX2: unpacks the j-th 32-byte chunk of 2-bit quants.
// values[k] holds bits 2k..2k+1 of each byte (values 0..3).
struct Q2Bits {
    inline void prepare(const uint8_t * q2, int j) {
        auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
        values[0] = _mm256_and_si256(q2bits, ml);
        values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
        values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
    }
    __m256i values[4];
    const __m256i ml = _mm256_set1_epi8(0x03);
};

// AVX2: unpacks the j-th 64-byte chunk of 4-bit quants into 4 x 32 bytes.
struct Q4Bits {
    // values[0] = low nibbles of bytes 0..31,  values[1] = high nibbles of bytes 0..31;
    // values[2] = low nibbles of bytes 32..63, values[3] = high nibbles of bytes 32..63.
    inline void prepare(const uint8_t * q4, int j) {
        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
        values[0] = _mm256_and_si256(q4bits, ml);
        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
        values[2] = _mm256_and_si256(q4bits, ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
    }
    // values[0] = low nibbles of bytes 0..31,  values[2] = high nibbles of bytes 0..31;
    // values[1] = low nibbles of bytes 32..63, values[3] = high nibbles of bytes 32..63.
    inline void prepare64(const uint8_t * q4, int j) {
        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
        values[0] = _mm256_and_si256(q4bits, ml);
        values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
        values[1] = _mm256_and_si256(q4bits, ml);
        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
    }
    // values[k] = dequant16() of the 16 bytes at q4 + 64*j + 16*k.
    inline void prepare16(const uint8_t * q4, int j) {
        values[0] = dequant16(q4 + 64*j +  0);
        values[1] = dequant16(q4 + 64*j + 16);
        values[2] = dequant16(q4 + 64*j + 32);
        values[3] = dequant16(q4 + 64*j + 48);
    }
    // 16 bytes -> [low nibbles (low 128 bits) | high nibbles (high 128 bits)].
    inline __m256i dequant16(const uint8_t * qs) const {
        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
        const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
        return _mm256_and_si256(ml, aux256);
    }
    __m256i values[4];
    const __m256i ml = _mm256_set1_epi8(0xf);
};

#endif

// In-place transpose of an 8x8 float matrix held as eight __m256 rows.
// On return, m[c] holds former column c.
inline void iqk_transpose_8x8(__m256 * m) {
    // Transpose 4x4 sub-blocks within each 128-bit half, for rows 0..3 and 4..7.
    for (int k = 0; k < 8; k += 4) {
        auto t0 = _mm256_unpacklo_ps(m[k+0], m[k+1]);
        auto t1 = _mm256_unpacklo_ps(m[k+2], m[k+3]);
        auto t2 = _mm256_unpackhi_ps(m[k+0], m[k+1]);
        auto t3 = _mm256_unpackhi_ps(m[k+2], m[k+3]);
        m[k+0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
        m[k+1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
        m[k+2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
        m[k+3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
    }
    // Swap the off-diagonal 128-bit halves between m[k] and m[k+4].
    for (int k = 0; k < 4; ++k) {
        auto t = _mm256_set_m128(_mm256_extractf128_ps(m[k+4], 1), _mm256_extractf128_ps(m[k], 1));
        m[k+0] = _mm256_set_m128(_mm256_castps256_ps128(m[k+4]), _mm256_castps256_ps128(m[k+0]));
        m[k+4] = t;
    }
}

#else
// ------------------------------------ __aarch64__ --------------------------------------------------
// This branch is taken whenever __AVX2__ is not defined. It uses ARM NEON and
// dot-product intrinsics; presumably IQK_IMPLEMENT is only enabled for AVX2 or
// NEON targets -- confirm in iqk_config.h.

// NEON view of nrc_y src1 rows of block_q8 blocks (needs `qs`, `bsums`, `d`).
template struct Q8 {

    constexpr static int nrc_y = nrc;

    Q8(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
    }

    // j-th 32-byte chunk of the quants of block i in row iy.
    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
    // j-th 64-byte chunk of the quants of block i in row iy.
    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
    // First 16 int16 bsums of block i in row iy.
    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
    // Those 16 bsums added in adjacent pairs -> 8 int16.
    inline int16x8_t load_bsums8(int iy, int i) const {
        auto q8s = vld1q_s16_x2(y[iy][i].bsums);
        return vpaddq_s16(q8s.val[0], q8s.val[1]);
    }
    inline float scale(int iy, int i) const { return y[iy][i].d; }

    const block_q8 * y[nrc_y];
};

// Common row-pointer state for NEON dequantizers. new_row(ix) points x at the
// blocks of row ix (row stride bx bytes). When has_row_scale is set, the row
// starts with a scale (ggml_half if scale_is_f16, else float) that is loaded
// into d, and the blocks follow it. Otherwise d is left as-is here.
template struct BaseDequantizer {
    BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
    inline void new_row(int ix) {
        if constexpr (has_row_scale) {
            if constexpr (scale_is_f16) {
                const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
                d = GGML_FP16_TO_FP32(*dptr);
                x = (const block_q *)(dptr + 1);
            } else {
                const float * dptr = (const float *)((const char *)vx + ix*bx);
                d = *dptr;
                x = (const block_q *)(dptr + 1);
            }
        } else {
            x = (const block_q *)((const char *)vx + ix*bx);
        }
    }

    const void *     vx;
    const block_q *  x;
    const size_t     bx;
    const int        nrc;
    float            d;
};

// NEON 4-bit unpacker: fills b1 and b2 (4 x 16 bytes each) with nibbles.
struct Q4bits {
    const uint8x16_t m4b = vdupq_n_u8(0xf);
    uint8x16x4_t b1, b2;
    // From 32 bytes val[0..1]:
    //   b.val[0] = lo(val[0]), b.val[1] = lo(val[1]),
    //   b.val[2] = hi(val[0]), b.val[3] = hi(val[1]).
    inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
        b.val[0] = vandq_u8(val[0], m4b);
        b.val[2] = vshrq_n_u8(val[0], 4);
        b.val[1] = vandq_u8(val[1], m4b);
        b.val[3] = vshrq_n_u8(val[1], 4);
    }
    // From 32 bytes val[0..1]:
    //   b.val[0] = lo(val[0]), b.val[1] = hi(val[0]),
    //   b.val[2] = lo(val[1]), b.val[3] = hi(val[1]).
    inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
        b.val[0] = vandq_u8(val[0], m4b);
        b.val[1] = vshrq_n_u8(val[0], 4);
        b.val[2] = vandq_u8(val[1], m4b);
        b.val[3] = vshrq_n_u8(val[1], 4);
    }
    // 64 bytes: b1 from bytes 0..31, b2 from bytes 32..63 (prepare4 layout).
    inline void prepare(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x2(qs);
        prepare4(b1, q4bits.val);
        q4bits = vld1q_u8_x2(qs+32);
        prepare4(b2, q4bits.val);
    }
    // Same result as prepare(), using one 64-byte load.
    inline void prepare_v2(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        prepare4(b1, q4bits.val+0);
        prepare4(b2, q4bits.val+2);
    }
    // 64 bytes: b1 = low nibbles of all 64 bytes, b2 = high nibbles.
    inline void prepare64(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        b1.val[0] = vandq_u8(q4bits.val[0], m4b);
        b1.val[1] = vandq_u8(q4bits.val[1], m4b);
        b1.val[2] = vandq_u8(q4bits.val[2], m4b);
        b1.val[3] = vandq_u8(q4bits.val[3], m4b);
        b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
        b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
        b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
        b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
    }
    // 64 bytes: b1 from bytes 0..31, b2 from bytes 32..63 (prepare4_16 layout).
    inline void prepare16(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x2(qs);
        prepare4_16(b1, q4bits.val);
        q4bits = vld1q_u8_x2(qs+32);
        prepare4_16(b2, q4bits.val);
    }
    // Same result as prepare16(), using one 64-byte load.
    inline void prepare16_v2(const uint8_t * qs) {
        auto q4bits = vld1q_u8_x4(qs);
        prepare4_16(b1, q4bits.val+0);
        prepare4_16(b2, q4bits.val+2);
    }
};

// NEON 2-bit unpacker for 32 bytes (two 16-byte vectors v0, v1):
//   b1.val[0..1] = bits 0-1 of v0, v1    b1.val[2..3] = bits 2-3
//   b2.val[0..1] = bits 4-5              b2.val[2..3] = bits 6-7
// Note: despite its name, m4b is the 2-bit mask 0x03 here.
struct Q2bits {
    const uint8x16_t m4b = vdupq_n_u8(0x03);
    uint8x16x4_t b1, b2;
    inline void prepare(const uint8_t * qs) {
        auto q2bits = vld1q_u8_x2(qs);
        b1.val[0] = vandq_u8(q2bits.val[0], m4b);
        b1.val[1] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b1.val[2] = vandq_u8(q2bits.val[0], m4b);
        b1.val[3] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b2.val[0] = vandq_u8(q2bits.val[0], m4b);
        b2.val[1] = vandq_u8(q2bits.val[1], m4b);

        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
        b2.val[2] = vandq_u8(q2bits.val[0], m4b);
        b2.val[3] = vandq_u8(q2bits.val[1], m4b);
    }
};

// Half j of a super-block with 32-quant sub-blocks.
// Dots the 4 sub-blocks (q8 chunks 4j..4j+3) against, respectively:
//   qx_1.val[0..1], qx_1.val[2..3], qx_2.val[0..1], qx_2.val[2..3].
// Each is reduced to one int32, then: sumi += scales.val[j] * [s0, s1, s2, s3] (lane-wise).
template
static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
        const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
    auto mzero = vdupq_n_s32(0);
    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
            vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // sub-block 4*j+0
    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
            vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // sub-block 4*j+1
    auto p12 = vpaddq_s32(p1, p2);
    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
            vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // sub-block 4*j+2
    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
            vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // sub-block 4*j+3
    auto p34 = vpaddq_s32(p3, p4);
    auto pall = vpaddq_s32(p12, p34);   // one sum per sub-block, in order
    sumi = vmlaq_s32(sumi, scales.val[j], pall);
}

// Half j of a super-block with 16-quant sub-blocks: 8 sub-blocks per half,
// one per 16-byte qx vector. Block numbers in the comments are relative to
// this half. Scales for blocks 0..3 come from scales.val[2*j] and for
// blocks 4..7 from scales.val[2*j+1].
template
static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
        const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {

    auto mzero = vdupq_n_s32(0);
    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
    auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,
    auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
    sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);

    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // blocks 4, 4, 5, 5,
    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // blocks 6, 6, 7, 7,
    auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
    sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
}

// Applies +/-1 signs to 16 int8 lanes, consuming two sign bytes per call.
//   * init() sets shuffle to {0 x8, 1 x8}.
//   * On the c-th call after init(), apply_signs_1() broadcasts signs16[2c] to
//     lanes 0..7 and signs16[2c+1] to lanes 8..15.
//   * It tests bit (lane % 8) using smask (little-endian byte order of the u64
//     constant): a set bit becomes -1 (0xff | 1), a clear bit becomes +1
//     (0x00 | 1).
//   * It multiplies b[0] by the result, then advances shuffle by 2.
struct SignHelper {

    inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }

    inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
        auto aux = vqtbl1q_u8(signs16, shuffle);
        auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
        shuffle = vaddq_u8(shuffle, step);
    }

    const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
    const uint8x16_t m1    = vdupq_n_u8(1);
    const uint8x16_t step  = vdupq_n_u8(2);
    uint8x16_t shuffle;
};

// Generic NEON kernel (mul_mat_t signature). It multiplies nrc_x rows of the
// quantized left operand (vx, row stride bx) by nrc_y q8 rows described by
// info, and stores one float per (row ix, src1 row iy) via info.store().
// Each row holds nb = n / QK_K super-blocks. For each super-block i, the
// Dequantizer either:
//   * handles scales and dot products itself (process_scales / compute), when
//     nrc_y > 1 and should_scale_quants(); or
//   * returns scales from new_block() that are combined by compute_8_blocks or
//     compute_16_blocks, according to num_blocks().
// Any other num_blocks() aborts at run time via GGML_ASSERT. The integer sums
// are then scaled by deq.d * q8.scale(iy, i) and accumulated in float.
template
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n % QK_K == 0);
    const int nb = n / QK_K;

    Q8 q8(info);

    Dequantizer deq(vx, bx, nrc_y);

    for (int ix = 0; ix < nrc_x; ++ix) {

        deq.new_row(ix);

        float32x4_t acc[nrc_y];
        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);

        for (int i = 0; i < nb; ++i) {

            int32x4_t sumi[nrc_y];
            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);

            if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
                deq.process_scales(i, q8, acc);
                deq.prepare(i, 0);
                deq.compute(q8, i, 0, sumi);
                deq.prepare(i, 1);
                deq.compute(q8, i, 1, sumi);
            } else {
                if constexpr (Dequantizer::num_blocks() == 8) {
                    auto scales = deq.new_block(i, q8, acc);
                    deq.prepare(i, 0);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
                    deq.prepare(i, 1);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
                }
                else if constexpr (Dequantizer::num_blocks() == 16) {
                    auto scales = deq.new_block(i, q8, acc);
                    deq.prepare(i, 0);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
                    deq.prepare(i, 1);
                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
                }
                else {
                    GGML_ASSERT(false);
                }
            }

            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
            }
        }

        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(acc[iy]));
        }
    }
}

// For r = 0..3:
//   sumi[r] = sum over k = 0..3 of
//     dot(qx[2k]   bytes 4r..4r+3, y.val[0] bytes 4k..4k+3)
//   + dot(qx[2k+1] bytes 4r..4r+3, y.val[1] bytes 4k..4k+3).
// This is consistent with 4 rows interleaved in 4-byte groups -- confirm
// against the repacked layout at the call sites.
static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
    auto sumi = vdupq_n_s32(0);
    sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
    sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
    sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
    sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
    sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
    sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
    sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
    sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
    return sumi;
}

// Same as interleaved_dotq above, but keeps the two halves apart: even-indexed
// qx (paired with y.val[0]) accumulate into val[0], odd-indexed qx (paired with
// y.val[1]) into val[1].
static IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
    int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
    sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
    sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
    return sumi;
}

// Single-vector variant: sumi[r] = sum over k = 0..3 of
// dot(qx[k] bytes 4r..4r+3, y bytes 4k..4k+3).
static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
    auto sumi = vdupq_n_s32(0);
    sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
    sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
    sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
    sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
    return sumi;
}

// Map the nibbles of four 16-byte vectors through the 16-entry table `values`.
// Low nibbles go to qx[0..3], high nibbles to qx[4..7]. The trailing comments
// give the quant positions each output covers.
static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
    qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   //  0...3 from the 4 rows
    qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
    qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   //  4...7
    qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
    qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  //  8..11
    qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
    qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
    qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
}

// Map the nibbles of two 16-byte vectors through `values`:
// qx[2k] = table[low nibbles of bits.val[k]], qx[2k+1] = table[high nibbles].
static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
    qx[0] = vqtbl1q_s8(values, vandq_u8(  bits.val[0], m4));
    qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0],  4));
    qx[2] = vqtbl1q_s8(values, vandq_u8(  bits.val[1], m4));
    qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1],  4));
}

#endif

#endif