Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  22
1 file changed, 7 insertions, 15 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 69c0b9a4..a9adedfc 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2917,15 +2917,16 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
template <int nrc_y>
static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
-#endif
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
- //auto values = load_iq4nl_values_256();
+#else
+ auto values = load_iq4nl_values_256();
+#endif
int nbl = n / QK_K;
using helper_t = union { __m256i vec; uint32_t val[8]; };
helper_t h;
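
Note on the hunk above: MM256_SET_M128I(values128, values128) places the same 16-byte lookup table in both 128-bit lanes of a 256-bit register, so that a per-lane byte shuffle (e.g. _mm256_shuffle_epi8, which translates indices within each 128-bit lane independently) can map packed 4-bit indices to their dequantized values. Below is a minimal, self-contained sketch of that step; the table contents are placeholders, not the repository's iq4k_values/iq4nl tables, and _mm256_broadcastsi128_si256 stands in for the MM256_SET_M128I macro.

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    // Placeholder 16-entry table (NOT the repository's iq4k/iq4nl values).
    static const int8_t lut[16] = { -8, -7, -6, -5, -4, -3, -2, -1,
                                     0,  1,  2,  3,  4,  5,  6,  7 };
    // 32 bytes of packed 4-bit indices, two per byte.
    uint8_t packed[32];
    for (int i = 0; i < 32; ++i)
        packed[i] = (uint8_t)((i & 0xf) | ((15 - (i & 0xf)) << 4));

    auto m4       = _mm256_set1_epi8(0xf);
    auto table128 = _mm_loadu_si128((const __m128i *)lut);
    // Same table in both lanes, mirroring MM256_SET_M128I(values128, values128).
    auto table    = _mm256_broadcastsi128_si256(table128);

    auto bits = _mm256_loadu_si256((const __m256i *)packed);
    auto lo   = _mm256_and_si256(bits, m4);                        // low nibbles
    auto hi   = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m4);  // high nibbles
    // shuffle_epi8 shuffles per 128-bit lane, hence the duplicated table.
    auto vlo  = _mm256_shuffle_epi8(table, lo);
    auto vhi  = _mm256_shuffle_epi8(table, hi);

    int8_t out[64];
    _mm256_storeu_si256((__m256i *)out + 0, vlo);
    _mm256_storeu_si256((__m256i *)out + 1, vhi);
    printf("first low-nibble value: %d, first high-nibble value: %d\n", out[0], out[32]);
    return 0;
}
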
@@ -2969,7 +2970,7 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
float d8 = q8.scale(iy, ibl);
- float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
#else
@@ -2979,15 +2980,6 @@ static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const
_mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
- //auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00))),
- // _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))));
- //auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa))),
- // _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))));
- //auto sumi = _mm256_add_epi32(sumi1, sumi2);
- //float d8 = q8.scale(iy, ibl);
- //float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
- //acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
- //acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
#endif
}
}
@@ -3057,7 +3049,7 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
float d8 = q8.scale(iy, ibl);
- float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, _mm512_set1_ps(d8)), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[2*iy+1]);
}
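
Note on the two m8 changes above: the per-sub-block correction is no longer recomputed in the inner loop as d8 * (bsums[2*ib+0] + bsums[2*ib+1]); instead the bsums storage is reinterpreted as floats and read directly. Together with the expected_typeB switch in the last hunk, this suggests the RHS is quantized to a layout (GGML_TYPE_Q8_K32) whose scaled 32-element sums are precomputed at quantization time. A scalar sketch of that assumed layout follows; the struct and field names are hypothetical, not the repository's.

#include <cstdint>
#include <cstdio>

// Hypothetical block layouts, for illustration only.
struct BlockOld {                 // q8_K-style: int16 sums of 16 values each
    float   d;
    int16_t bsums[16];            // two entries per 32-value sub-block
};
struct BlockNew {                 // layout assumed by the patch:
    float   d;
    float   sums[8];              // d * (sum of 32 int8 values), precomputed
};

int main() {
    // Pretend quantized data: sub-block 0 has halves summing to 100 and -40.
    BlockOld o{0.125f, {100, -40}};
    int ib = 0;
    float m8_old = o.d * (o.bsums[2*ib + 0] + o.bsums[2*ib + 1]);   // 7.5

    // New path: the same value was already computed when the RHS was
    // quantized, so the kernel just loads one float per sub-block.
    BlockNew n{0.125f, {0.125f * (100 - 40)}};
    float m8_new = n.sums[ib];                                      // 7.5

    printf("m8 old = %g, m8 new = %g\n", m8_old, m8_new);
    return 0;
}
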
@@ -5074,7 +5066,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>;
mm.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>;
mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
- expected_typeB = GGML_TYPE_Q8_K;
+ expected_typeB = GGML_TYPE_Q8_K32;
break;
case GGML_TYPE_Q4_0_R4:
assert (ne00 % QK4_NL == 0);
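
The dispatch change in the last hunk ties the patch together: MulMat::prepare() installs one kernel instantiation per possible number of RHS columns and records which RHS ("B") quantization those kernels expect, which this commit switches from GGML_TYPE_Q8_K to GGML_TYPE_Q8_K32. A hedged sketch of that dispatch pattern with illustrative names (not the repository's actual declarations):

#include <cstdio>

enum class TypeB { Q8_K, Q8_K32 };

// Stand-in for the templated matmul kernels; nrc_y is the RHS column count.
template <int nrc_y>
static void kernel_iq4_xs_r4(int n) {
    printf("kernel for %d RHS columns, n=%d\n", nrc_y, n);
}

struct MulMatSketch {
    void (*funcs[8])(int) = {};
    TypeB expected_typeB  = TypeB::Q8_K;
};

static void prepare_iq4_xs_r4(MulMatSketch& mm) {
    mm.funcs[0] = kernel_iq4_xs_r4<1>;
    mm.funcs[1] = kernel_iq4_xs_r4<2>;
    mm.funcs[2] = kernel_iq4_xs_r4<3>;
    mm.funcs[3] = kernel_iq4_xs_r4<4>;
    mm.funcs[4] = kernel_iq4_xs_r4<5>;
    mm.funcs[5] = kernel_iq4_xs_r4<6>;
    mm.funcs[6] = kernel_iq4_xs_r4<7>;
    mm.funcs[7] = kernel_iq4_xs_r4<8>;
    // Counterpart of expected_typeB = GGML_TYPE_Q8_K32 in the hunk above:
    // the RHS must now use the layout whose block sums are stored as floats.
    mm.expected_typeB = TypeB::Q8_K32;
}

int main() {
    MulMatSketch mm;
    prepare_iq4_xs_r4(mm);
    int Ny = 3;                 // number of RHS columns in this call
    mm.funcs[Ny - 1](256);      // pick the kernel instantiated for Ny columns
    return 0;
}
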