diff options
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 75 |
1 files changed, 37 insertions, 38 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index c561ca2b..aeba2c59 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3528,26 +3528,27 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D template <int nrc_y> static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8<nrc_y, block_q8_1_x4> q8(info); + Q8<nrc_y, block_q8_K128> q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); __m256i qx[4]; __m256 acc[nrc_y] = {}; auto m1 = _mm256_set1_epi16(1); auto ms = _mm_set1_epi16(-32768); - float d8[8*nrc_y]; + float d8[4*nrc_y]; union { __m256i vec; uint16_t val[16]; } helper; struct aux_iq1_s_r4 { uint8_t qs[16]; uint64_t qh; }; - for (int ix= 0; ix < nrc_x; ix += 4) { + for (int ix = 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)); auto x = (const aux_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8 + 8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib].d))); + auto bsums = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].bsums)); + _mm_storeu_ps(d8 + 4*iy, _mm_mul_ps(_mm_set1_ps(q8.y[iy][ib].d), _mm_cvtepi32_ps(bsums))); } for (int k = 0; k < 4; ++k) { auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh); @@ -3556,8 +3557,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1)); auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1)); signs = _mm_add_epi16(_mm_set1_epi16(-8), signs); - auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32( - _mm_mullo_epi16(scales4, signs)))); + signs = _mm_mullo_epi16(signs, scales4); + auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32(signs))); auto delta = _mm256_set_m128(delta4, delta4); scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3 auto scales = MM256_SET_M128I(scales4, scales4); @@ -3598,8 +3599,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI auto sumi = _mm256_packs_epi32(sumi1, sumi2); #endif sumi = _mm256_madd_epi16(scales, sumi); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+0]), _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+4]), delta, acc[iy]); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), delta, acc[iy]); } } } @@ -3614,7 +3615,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI template <int nrc_y> static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8<nrc_y, block_q8_0_x4> q8(info); + Q8<nrc_y, block_q8_K128> q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000); @@ -3624,17 +3625,14 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI #endif __m256i qx[4]; __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; auto ms = _mm_set1_epi8(0x08); - float d8[4*nrc_y]; union { __m256i vec; uint16_t val[16]; } helper; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr))); auto x = (const block_iq1_m_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { - for (int iy = 0; iy < nrc_y; ++iy) { - _mm_storeu_ps(d8 + 4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].d))); - } for (int k = 0; k < 4; ++k) { auto qh = (const uint32_t *)x[4*ib+k].qh; auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]); @@ -3694,10 +3692,13 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t auto sumi = _mm256_packs_epi32(sumi1, sumi2); #endif - sumi = _mm256_madd_epi16(scales, sumi); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi)); } } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } } for (int iy = 0; iy < nrc_y; ++iy) { auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); @@ -9177,7 +9178,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #ifdef HAVE_FANCY_SIMD mm.func16 = mul_mat_iq1_s_r4_q8_1<16>; #endif - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_K128; break; case GGML_TYPE_IQ1_M_R4: assert (ne00 % QK4_NL == 0); @@ -9192,7 +9193,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #ifdef HAVE_FANCY_SIMD mm.func16 = mul_mat_iq1_m_r4_q8_0<16>; #endif - expected_typeB = GGML_TYPE_Q8_0_X4; + expected_typeB = GGML_TYPE_Q8_K128; break; default: @@ -12072,7 +12073,7 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8<1, block_q8_1_x4> q8(info); + Q8<1, block_q8_K128> q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; @@ -12084,8 +12085,8 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); auto x = (const block_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { - auto scale_yd = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+0)); - auto scale_ym = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+4)); + auto scale_yd = vdupq_n_f32(q8.y[0][ib].d); + auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums)))); for (int k = 0; k < 4; ++k) { auto sas = vld1_u16(x[4*ib+k].qh); auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7)); @@ -12135,23 +12136,22 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat template <int nrc_y> static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8<nrc_y, block_q8_1_x4> q8(info); + Q8<nrc_y, block_q8_K128> q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); uint8x16_t qx[8]; int32x4_t acc[nrc_y] = {}; auto ms = vdup_n_u16(0x8000); auto mask = vdupq_n_s8(0x03); - float d8[8*nrc_y]; + float d8[4*nrc_y]; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); auto x = (const block_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = vld1q_f16((const float16_t *)q8.y[iy][ib].d); - vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vget_low_f16(scales))); - vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vget_high_f16(scales))); + auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums))); + vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales)); } for (int k = 0; k < 4; ++k) { auto sas = vld1_u16(x[4*ib+k].qh); @@ -12193,8 +12193,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); sumi = vmulq_s32(scales, sumi); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+0]), vcvtq_f32_s32(sumi)); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+4]), delta4); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi)); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4); } } } @@ -12208,25 +12208,21 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI template <int nrc_y> static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8<nrc_y, block_q8_0_x4> q8(info); + Q8<nrc_y, block_q8_K128> q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; - int32x4_t acc[nrc_y] = {}; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303}; auto step = vdupq_n_u8(4); auto ms = vdupq_n_u8(0x08); auto mask = vdupq_n_s8(0x18); - float d8[4*nrc_y]; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr))); auto x = (const block_iq1_m_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { - for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = vld1_f16((const float16_t *)q8.y[iy][ib].d); - vst1q_f32(d8+4*iy, vcvt_f32_f16(scales)); - } for (int k = 0; k < 4; ++k) { auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]); scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf)); @@ -12272,10 +12268,13 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1); sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); - auto sumi = vmlaq_s32(vmlaq_s32(vdupq_n_s32(0), sumi1, scales1), sumi2, scales2); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), vcvtq_f32_s32(sumi)); + isum[iy] = vmlaq_s32(vmlaq_s32(isum[iy], sumi1, scales1), sumi2, scales2); } } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vmulq_f32(d1, acc[iy])); @@ -13907,12 +13906,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1); m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1; m.func16 = mul_mat_iq1_s_r4_q8_1<16>; - expected_Btype = GGML_TYPE_Q8_1_X4; + expected_Btype = GGML_TYPE_Q8_K128; break; case GGML_TYPE_IQ1_M_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0); m.func16 = mul_mat_iq1_m_r4_q8_0<16>; - expected_Btype = GGML_TYPE_Q8_0_X4; + expected_Btype = GGML_TYPE_Q8_K128; break; case GGML_TYPE_IQ3_XXS_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k); |