Diffstat (limited to 'ggml/src/iqk/iqk_gemm_legacy_quants.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 78 |
1 file changed, 78 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 6e262aab..17d2dad3 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -1615,6 +1615,81 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
 }
 #endif
 
+typedef struct {
+    ggml_half d[16];
+    uint8_t qs[256];
+} block_q8_1_r8;
+
+template <int nrc_y>
+static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+    int nb = n / QK8_0;
+    __m256 acc[nrc_y] = {};
+    float d8[4*nrc_y];
+    __m256i qx[4];
+    auto dot = [&qx] (const int8_t * qy) {
+        auto y128 = _mm_loadu_si128((const __m128i*)qy);
+        auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+        auto sumi = _mm256_setzero_si256();
+        sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+        return sumi;
+#else
+        auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                      _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+        auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                      _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+        return _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), sumi1), _mm256_madd_epi16(_mm256_set1_epi16(1), sumi2));
+#endif
+    };
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
+        for (int i4 = 0; i4 < nb/4; ++i4) {
+            {
+                __m256 mx[4];
+                for (int ib32 = 0; ib32 < 4; ++ib32) mx[ib32] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d+1));
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][i4].d)), 16));
+                    _mm_storeu_ps(d8 + 4*iy + 0, scales);
+                    auto bsums4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][i4].d+4))), 16));
+                    auto bsums = _mm256_set_m128(bsums4, bsums4);
+                    acc[iy] = _mm256_fmadd_ps(mx[0], _mm256_shuffle_ps(bsums, bsums, 0x00), acc[iy]);
+                    acc[iy] = _mm256_fmadd_ps(mx[1], _mm256_shuffle_ps(bsums, bsums, 0x55), acc[iy]);
+                    acc[iy] = _mm256_fmadd_ps(mx[2], _mm256_shuffle_ps(bsums, bsums, 0xaa), acc[iy]);
+                    acc[iy] = _mm256_fmadd_ps(mx[3], _mm256_shuffle_ps(bsums, bsums, 0xff), acc[iy]);
+                }
+            }
+            for (int ib32 = 0; ib32 < 4; ++ib32) {
+                auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d));
+                for (int j = 0; j < 4; ++j) {
+                    qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+j);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto sumi = dot(q8.y[iy][i4].qs+32*ib32);
+                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32]));
+                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+                }
+                for (int j = 0; j < 4; ++j) {
+                    qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+4+j);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto sumi = dot(q8.y[iy][i4].qs+32*ib32+16);
+                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32]));
+                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+                }
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, acc[iy]);
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+
 template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
     if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
                   std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
@@ -1694,6 +1769,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
         case GGML_TYPE_IQ4_NL_R4:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
             break;
+        case GGML_TYPE_Q8_1: // Note: we are misusing the Q8_1 type for Q8_1_R8
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_2, kernels)
+            break;
         default:
             return false;
     }
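For readers unfamiliar with the interleaved `_r8` layouts, the scalar sketch below illustrates the per-block math the AVX2/AVX512 kernel above appears to perform. It is a minimal sketch, not code from the repository, and it assumes the layout implied by the loads in `mul_mat_q8_1_r8_q8_2`: `d[0..7]` hold the eight per-row scales, `d[8..15]` hold the eight per-row terms that get multiplied with the activation block sums, and `qs` interleaves the eight rows' 32 quants in 4-byte groups. The names `half_to_float` and `ref_dot_q8_1_r8` are hypothetical helpers introduced here for illustration only.

// Scalar reference sketch (not production code) for one block_q8_1_r8,
// i.e. one 32-wide block shared by 8 interleaved rows of the left matrix.
#include <cstdint>
#include <cstring>

using ggml_half = uint16_t;   // raw fp16 bits, as in ggml

// Minimal fp16 -> fp32 conversion; subnormals are treated as zero in this sketch.
static float half_to_float(ggml_half h) {
    uint32_t sign = uint32_t(h >> 15) << 31;
    uint32_t exp  = (h >> 10) & 0x1f;
    uint32_t mant = h & 0x3ff;
    uint32_t bits = sign;
    if (exp == 31)     bits |= 0x7f800000u | (mant << 13);          // inf / NaN
    else if (exp != 0) bits |= ((exp + 112u) << 23) | (mant << 13); // normal value
    float f; std::memcpy(&f, &bits, sizeof(f));
    return f;
}

struct block_q8_1_r8 {
    ggml_half d[16];    // d[0..7]: per-row scales, d[8..15]: per-row "m" terms
    uint8_t   qs[256];  // 8 rows x 32 quants, interleaved in 4-byte groups
};

// One activation (y) block of 32 signed quants with scale dy and block-sum term sy.
// Accumulates result[r] += d[r]*dy*sum_i(qs[r][i]*qy[i]) + m[r]*sy for r = 0..7,
// mirroring what the kernel does with _mm256_dpbusd_epi32 / _mm256_maddubs_epi16
// (x quants used as unsigned, y quants as signed).
static void ref_dot_q8_1_r8(const block_q8_1_r8& x, const int8_t* qy,
                            float dy, float sy, float* result) {
    for (int r = 0; r < 8; ++r) {
        int sumi = 0;
        for (int g = 0; g < 8; ++g) {          // 8 groups of 4 quant positions
            for (int i = 0; i < 4; ++i) {
                // qs[32*g + 4*r + i] is row r, quant position 4*g + i
                sumi += int(x.qs[32*g + 4*r + i]) * int(qy[4*g + i]);
            }
        }
        result[r] += half_to_float(x.d[r]) * dy * sumi
                   + half_to_float(x.d[r + 8]) * sy;
    }
}

In this sketch, dy and sy play the roles of the per-sub-block scale and scaled block sum that the kernel reads from the bf16 header of block_q8_2_x4 (q8.y[iy][i4].d and q8.y[iy][i4].d+4, respectively).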