summary | refs | log | tree | commit | diff
path: root/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_gemm_legacy_quants.cpp')
-rw-r--r-- ggml/src/iqk/iqk_gemm_legacy_quants.cpp 78
1 files changed, 78 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 6e262aab..17d2dad3 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -1615,6 +1615,81 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
}
#endif
+typedef struct { // Block layout consumed by mul_mat_q8_1_r8_q8_2, which walks the rows 8 at a time.
+ ggml_half d[16]; // 16 half-precision values: the kernel reads d[0..7] as per-lane scales and d[8..15] as the factors multiplied with the activation block sums
+ uint8_t qs[256]; // 256 unsigned quant bytes, loaded by the kernel as eight 32-byte vectors (qs+0..3, then qs+4..7)
+} block_q8_1_r8;
+
+template <int nrc_y> // nrc_y: number of y rows (q8.y[iy]) handled per call; sizes acc[] and d8[]
+static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { // Multiplies block_q8_1_r8 data at vx with the block_q8_2_x4 rows wrapped by info; stores one __m256 (8 floats) per (ix, iy)
+ GGML_ASSERT(nrc_x%8 == 0); // rows are processed in groups of 8 (see the ix += 8 loop)
+ Q8<nrc_y, block_q8_2_x4> q8(info);
+ int nb = n / QK8_0; // number of QK8_0-sized blocks along n
+ __m256 acc[nrc_y] = {}; // one 8-lane float accumulator per y row
+ float d8[4*nrc_y]; // 4 activation scales per y row for the current i4 group
+ __m256i qx[4]; // four 32-byte vectors of unsigned quants; read by dot through its by-reference capture
+ auto dot = [&qx] (const int8_t * qy) { // dots 16 int8 activations against qx[0..3]; returns 8 int32 partial sums
+ auto y128 = _mm_loadu_si128((const __m128i*)qy);
+ auto y = MM256_SET_M128I(y128, y128); // same 16 bytes in both 128-bit lanes
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); // 0x00 broadcasts dword 0 (activation bytes 0..3) to every 32-bit element
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); // dword 1 (bytes 4..7)
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); // dword 2 (bytes 8..11)
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); // dword 3 (bytes 12..15)
+ return sumi;
+#else // fallback: u8*s8 products summed pairwise into int16, then widened to int32 via madd with 1
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), // NOTE(review): maddubs saturates and add_epi16 wraps at 16 bits - confirm the qs value range keeps these sums in range
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ return _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), sumi1), _mm256_madd_epi16(_mm256_set1_epi16(1), sumi2));
+#endif
+ };
+ for (int ix = 0; ix < nrc_x; ix += 8) { // each iteration produces results for 8 rows
+ const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx); // byte offset ix*bx; bx is presumably the row stride in bytes - TODO confirm
+ for (int i4 = 0; i4 < nb/4; ++i4) { // groups of 4 blocks; NOTE(review): any nb%4 remainder is not processed here - confirm callers guarantee nb is a multiple of 4
+ {
+ __m256 mx[4];
+ for (int ib32 = 0; ib32 < 4; ++ib32) mx[ib32] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d+1)); // the cast binds before +1, so this loads d[8..15]
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][i4].d)), 16)); // 4 x 16-bit values from d[0..3] moved to the top half of a float (presumably bf16 -> f32; TODO confirm)
+ _mm_storeu_ps(d8 + 4*iy + 0, scales); // kept for the per-block scaling below
+ auto bsums4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][i4].d+4))), 16)); // same conversion applied to d[4..7], used as per-block sums
+ auto bsums = _mm256_set_m128(bsums4, bsums4); // duplicate into both 128-bit lanes
+ acc[iy] = _mm256_fmadd_ps(mx[0], _mm256_shuffle_ps(bsums, bsums, 0x00), acc[iy]); // acc += mx[k] * bsums[k] for k = 0..3
+ acc[iy] = _mm256_fmadd_ps(mx[1], _mm256_shuffle_ps(bsums, bsums, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(mx[2], _mm256_shuffle_ps(bsums, bsums, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(mx[3], _mm256_shuffle_ps(bsums, bsums, 0xff), acc[iy]);
+ }
+ }
+ for (int ib32 = 0; ib32 < 4; ++ib32) {
+ auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d)); // d[0..7] converted to 8 floats
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+j); // 32-byte vectors 0..3 of qs
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(q8.y[iy][i4].qs+32*ib32); // first 16 activations of sub-block ib32
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32])); // combined scale: block scales times the activation scale
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+4+j); // 32-byte vectors 4..7 of qs
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = dot(q8.y[iy][i4].qs+32*ib32+16); // second 16 activations of sub-block ib32
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32]));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]); // hand the 8 accumulated floats for (ix, iy) to info
+ acc[iy] = _mm256_setzero_ps(); // reset for the next group of 8 rows
+ }
+ }
+}
+
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
@@ -1694,6 +1769,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
case GGML_TYPE_IQ4_NL_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
break;
+ case GGML_TYPE_Q8_1: // Note: we are misusing the Q8_1 type for Q8_1_R8
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_2, kernels)
+ break;
default:
return false;
}