Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp | 269
1 file changed, 225 insertions, 44 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index cf4bd7ab..ca75e0fd 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -204,6 +204,7 @@ struct MulMat {
             case GGML_TYPE_IQ4_KS_R4:
             case GGML_TYPE_IQ2_XXS_R4:
             case GGML_TYPE_IQ3_XXS_R4:
+            case GGML_TYPE_IQ3_S_R4:
             case GGML_TYPE_IQ2_BN_R4: return 4;
             case GGML_TYPE_Q8_K_R8: return 8;
             case GGML_TYPE_BF16_R16: return 16;
@@ -3981,6 +3982,136 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
     }
 }
 
+#ifdef HAVE_FANCY_SIMD
+// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster
+// compared to the vanilla AVX2 version below.
+struct IndexHelperIQ3S {
+    union index_t {
+        __m256i vec;
+        uint16_t val[16];
+    };
+    inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
+        auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs));
+        const __mmask16 * m16 = (const __mmask16 *)qh;
+        index_t idx;
+        idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset);
+        values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
+                                     iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
+        values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
+                                     iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]);
+    }
+    const __m256i offset = _mm256_set1_epi16(256);
+};
+#else
+struct IndexHelperIQ3S {
+    union index_t {
+        __m256i vec;
+        uint32_t val[8];
+    };
+    inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
+        index_t idx;
+        auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
+        auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
+        idx.vec = _mm256_or_si256(idx_h, idx_l);
+        values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
+                                     iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
+        idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8)));
+        idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
+        idx.vec = _mm256_or_si256(idx_h, idx_l);
+        values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
+                                     iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
+    }
+    const __m256i idx_mask = _mm256_set1_epi32(256);
+    const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
+};
+#endif
+
+template <int nrc_y>
+static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    auto smask = _mm256_set1_epi8(1);
+    union { __m256i vec; uint32_t val[8]; } helper;
+    union { __m128i vec; uint16_t val[8]; } hidx;
+    __m256 acc[nrc_y] = {};
+    __m256i isum[nrc_y] = {};
+    __m256i qx[4];
+#ifdef HAVE_FANCY_SIMD
+    __mmask32 mask[4];
+#endif
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
+            auto d4 = _mm256_set_m128(dl, dl);
+            auto qs = iq3[ibl].qs;
+            auto qh = iq3[ibl].qh;
+            auto scale_bits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales);
+            auto scales8 = MM256_SET_M128I(_mm_srli_epi16(scale_bits, 4), scale_bits);
+            helper.vec = _mm256_or_si256(_mm256_slli_epi16(_mm256_and_si256(scales8, _mm256_set1_epi8(0xf)), 1), _mm256_set1_epi8(1));
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto qh32 = (const uint32_t *)qh;
+                auto idx_h = _mm_sllv_epi64(_mm_cvtepu8_epi16(_mm_set1_epi32(qh32[0])), _mm_set_epi64x(4, 8));
+                for (int i = 0; i < 4; ++i) {
+                    auto idx_l = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)(qs + 8*i)));
+                    hidx.vec = _mm_or_si128(idx_l, _mm_and_si128(idx_h, _mm_set1_epi16(0x100))); idx_h = _mm_srli_epi16(idx_h, 1);
+                    qx[i] = _mm256_set_epi32(iq3s_grid[hidx.val[7]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[4]],
+                                             iq3s_grid[hidx.val[3]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[0]]);
+                }
+                qs += 32; qh += 4;
+                auto signs128 = _mm_loadu_si128((const __m128i*)iq3[ibl].signs + ib);
+                auto signs = MM256_SET_M128I(_mm_srli_epi16(signs128, 4), signs128);
+#ifdef HAVE_FANCY_SIMD
+                auto scales = _mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib]));
+                mask[0] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
+                mask[1] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
+                mask[2] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
+                mask[3] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+                    auto sumi = _mm256_setzero_si256();
+                    auto ys = _mm256_shuffle_epi32(y, 0x00);
+                    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_mask_sub_epi8(ys, mask[0], _mm256_setzero_si256(), ys));
+                    ys = _mm256_shuffle_epi32(y, 0x55);
+                    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_mask_sub_epi8(ys, mask[1], _mm256_setzero_si256(), ys));
+                    ys = _mm256_shuffle_epi32(y, 0xaa);
+                    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_mask_sub_epi8(ys, mask[2], _mm256_setzero_si256(), ys));
+                    ys = _mm256_shuffle_epi32(y, 0xff);
+                    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_mask_sub_epi8(ys, mask[3], _mm256_setzero_si256(), ys));
+                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales));
+                }
+#else
+                auto scales16 = _mm256_cvtepi8_epi16(_mm_set1_epi32(helper.val[ib]));
+                auto scales = _mm256_unpacklo_epi16(scales16, scales16);
+                auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
+                auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
+                auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
+                auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
+                    auto sumi = _mm256_setzero_si256();
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1)));
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2)));
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3)));
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4)));
+                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
+                }
+#endif
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+                isum[iy] = _mm256_setzero_si256();
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+            info.store(ix, iy, sum);
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
@@ -5785,50 +5916,6 @@ static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataIn
 #endif
 }
 
-//#ifdef HAVE_FANCY_SIMD
-// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster
-// compared to the vanilla AVX2 version below.
-//struct IndexHelperIQ3S {
-//    union index_t {
-//        __m256i vec;
-//        uint16_t val[16];
-//    };
-//    inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
-//        auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs));
-//        const __mmask16 * m16 = (const __mmask16 *)qh;
-//        index_t idx;
-//        idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset);
-//        values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
-//                                     iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
-//        values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
-//                                     iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]);
-//    }
-//    const __m256i offset = _mm256_set1_epi16(256);
-//};
-//#else
-struct IndexHelperIQ3S {
-    union index_t {
-        __m256i vec;
-        uint32_t val[8];
-    };
-    inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
-        index_t idx;
-        auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
-        auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
-        idx.vec = _mm256_or_si256(idx_h, idx_l);
-        values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
-                                     iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
-        idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8)));
-        idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
-        idx.vec = _mm256_or_si256(idx_h, idx_l);
-        values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
-                                     iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
-    }
-    const __m256i idx_mask = _mm256_set1_epi32(256);
-    const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
-};
-//#endif
-
 struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
     DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
 
@@ -7438,6 +7525,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             mm.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
            expected_typeB = GGML_TYPE_Q8_K;
            break;
+        case GGML_TYPE_IQ3_S_R4:
+            assert (ne00 % QK_K == 0);
+            mm.funcs[0] = mul_mat_iq3_s_r4_q8_k<1>;
+            mm.funcs[1] = mul_mat_iq3_s_r4_q8_k<2>;
+            mm.funcs[2] = mul_mat_iq3_s_r4_q8_k<3>;
+            mm.funcs[3] = mul_mat_iq3_s_r4_q8_k<4>;
+            mm.funcs[4] = mul_mat_iq3_s_r4_q8_k<5>;
+            mm.funcs[5] = mul_mat_iq3_s_r4_q8_k<6>;
+            mm.funcs[6] = mul_mat_iq3_s_r4_q8_k<7>;
+            mm.funcs[7] = mul_mat_iq3_s_r4_q8_k<8>;
+            mm.func16 = mul_mat_iq3_s_r4_q8_k<16>;
+            expected_typeB = GGML_TYPE_Q8_K;
+            break;
        case GGML_TYPE_Q2_K_R4:
            assert (ne00 % QK_K == 0);
            mm.funcs[0] = mul_mat_q2_k_r4_q8_k<1>;
@@ -10547,6 +10647,82 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
     }
 }
 
+template <int nrc_y>
+static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    float32x4_t acc[nrc_y] = {};
+    int32x4_t isum[nrc_y] = {};
+    int8x16_t qx[8];
+    auto m1 = vdupq_n_u8(1);
+    auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03});
+    uint32_t stored_scales[8];
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
+            auto qs = iq3[ibl].qs;
+            auto qh = iq3[ibl].qh;
+            auto scale_bits = vld1q_u8(iq3[ibl].scales);
+            uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
+            scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1);
+            scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1);
+            vst1q_u8_x2((uint8_t *)stored_scales, scales8);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
+                if constexpr (nrc_y == 1) {
+                    auto qh32 = (const uint32_t *)qh;
+                    auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4}));
+                    union { uint16x8_t vec; uint16_t val[8]; } hidx;
+                    for (int i = 0; i < 4; ++i) {
+                        auto idx_l = vmovl_u8(vld1_u8(qs));
+                        hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1);
+                        qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]});
+                        auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
+                        qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
+                        qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]});
+                        signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
+                        qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
+                        signs128 = vshrq_n_u8(signs128, 1);
+                        qs += 8;
+                    }
+                } else {
+                    for (int i = 0; i < 4; ++i) {
+                        qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)],
+                                                                    iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]});
+                        auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
+                        qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
+
+                        qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)],
+                                                                    iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]});
+                        signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
+                        qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
+
+                        qs += 8;
+                        signs128 = vshrq_n_u8(signs128, 1);
+                    }
+                }
+                auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff));
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
+                    auto sumi = interleaved_dotq(qx, y);
+                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+                }
+                qh += 4;
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+                isum[iy] = vdupq_n_s32(0);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, acc[iy]);
+            acc[iy] = vdupq_n_f32(0.f);
+        }
+    }
+}
+
 template <int nrc_y, int k_shift>
 inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales,
                             uint8x16_t extra, int32x4_t * isum) {
@@ -11864,6 +12040,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
            m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
            expected_Btype = GGML_TYPE_Q8_K;
            break;
+        case GGML_TYPE_IQ3_S_R4:
+            SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_s_r4_q8_k);
+            m.func16 = mul_mat_iq3_s_r4_q8_k<16>;
+            expected_Btype = GGML_TYPE_Q8_K;
+            break;
        case GGML_TYPE_Q2_K_R4:
            SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k);
            expected_Btype = GGML_TYPE_Q8_K;
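
Note (not part of the patch): the AVX2/AVX512 and NEON kernels above decode IQ3_S_R4 data the same way. Each quant index is 9 bits wide: the low 8 bits come from a qs byte and the 9th bit from qh; the index selects a uint32 entry of iq3s_grid that packs four quant magnitudes; a per-value sign bit negates the magnitude; and each 32-value sub-block carries a 4-bit scale s that enters the dot product as 2*s + 1, with the per-row fp16 scale and the Q8_K activation scale applied at the end. The scalar sketch below is only a hypothetical reference to make that arithmetic explicit: the flat, non-interleaved layout, the bit ordering assumed for qh and signs, the function name dot_iq3s_subblock, and the extern declaration of iq3s_grid are illustrative assumptions, not taken from the patch (the real _r4 layout interleaves four rows and the kernels fold the signs into the SIMD dot product).

    // Hypothetical scalar reference for one 32-value IQ3_S sub-block (illustrative only).
    #include <cstdint>

    extern const uint32_t iq3s_grid[512];   // 512-entry lookup table used by the kernels above;
                                            // each entry packs four quant magnitudes into one uint32

    static float dot_iq3s_subblock(const uint8_t qs[8],   // low 8 bits of the 8 grid indices
                                   uint8_t      qh,       // 9th bit of each index, one bit per qs byte
                                   uint32_t     signs,    // one sign bit per quant value (1 = negate)
                                   uint8_t      scale,    // 4-bit sub-block scale s, used as 2*s + 1
                                   float        d,        // per-row super-block scale (fp16, already converted)
                                   const int8_t y[32]) {  // 32 Q8_K activations (their block scale omitted here)
        int sumi = 0;
        for (int j = 0; j < 8; ++j) {
            // 9-bit index -> grid entry holding 4 packed magnitudes
            uint32_t grid = iq3s_grid[qs[j] | (((qh >> j) & 1) << 8)];
            const uint8_t * g = (const uint8_t *)&grid;
            for (int k = 0; k < 4; ++k) {
                int v = ((signs >> (4*j + k)) & 1) ? -int(g[k]) : int(g[k]);
                sumi += v * y[4*j + k];
            }
        }
        return d * float(2*scale + 1) * float(sumi);
    }

This mirrors how the kernels build their integer accumulators: the 2*s + 1 scale corresponds to the (nibble << 1) | 1 step applied to iq3[ibl].scales, and the sign handling corresponds to the cmpeq-with-1 masks (x86) and vceqq/vmulq pairs (NEON) above.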