diff options
Diffstat (limited to 'ggml/src/iqk/iqk_gemm_iquants.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_iquants.cpp | 238 |
1 files changed, 235 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp index 782e48d8..9e2ddc10 100644 --- a/ggml/src/iqk/iqk_gemm_iquants.cpp +++ b/ggml/src/iqk/iqk_gemm_iquants.cpp @@ -87,13 +87,12 @@ struct EvenSignHelper { const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0); const __m256i mask = _mm256_set1_epi32(127); const __m256i mone = _mm256_set1_epi32(1); -#else +#endif inline void sign_value(uint32_t aux32, __m256i& value) const { auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]); value = _mm256_sign_epi8(value, signs); } -#endif }; struct SignHelper { @@ -144,6 +143,35 @@ struct SignHelper { const __m256i mone = _mm256_set1_epi8(1); }; +// for (int i = 0; i < nb; ++i) { +// +// __m256i sumi[nrc_y], all_scales; +// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256(); +// __m256i mins; +// float dmin = deq.new_block(i, &all_scales, mins); +// for (int iy = 0; iy < nrc_y; ++iy) { +// auto bsums = q8.load_bsums(iy, i); +// auto prod = _mm256_madd_epi16(mins, bsums); +// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); +// } +// +// for (int j = 0; j < QK_K/128; ++j) { +// deq.prepare(i, j); +// set_scales_8(&all_scales, j, scales); +// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi); +// multiply_add(deq.bits, scales, j, i, q8, sumi); +// } +// for (int iy = 0; iy < nrc_y; ++iy) { +// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); +// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); +// } +// } +// +// for (int iy = 0; iy < nrc_y; ++iy) { +// info.store(ix, iy, hsum_float_8(accd[iy])); +// } +// } + struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> { DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} @@ -165,6 +193,16 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> { auto sc16 = load_scales(i); scales[0] = MM256_SET_M128I(sc16, sc16); } + inline void new_block_f(int i, __m256 * scales) { + auto sc16 = load_scales(i); + auto scf = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(sc16))); + auto scf_l = _mm256_castps256_ps128(scf); + auto scf_h = _mm256_extractf128_ps(scf, 1); + scales[0] = _mm256_set_m128(scf_l, scf_l); + scales[1] = _mm256_set_m128(scf_h, scf_h); + scales[2] = _mm256_mul_ps(scf, _mm256_set1_ps(-minv)); + } + inline float new_block(int i, __m256i * scales, __m256i& mins) { auto sc16 = load_scales(i); mins = scb.shuffle(sc16); @@ -730,6 +768,130 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data } template <typename Dequantizer, int nrc_y> +static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + static_assert(Dequantizer::num_blocks == 8); + const int nb = n / QK_K; + Q8<nrc_y, block_q8_2_x4> q8(info); + Dequantizer deq(vx, bx); + __m256 scales[3]; + __m256 accd[nrc_y]; + __m256i sumi[4]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + deq.new_row(ix); + + for (int i = 0; i < nb; ++i) { + + deq.new_block_f(i, scales); + for (int iy = 0; iy < nrc_y; ++iy) { + auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4))); + auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4))); + auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16)); + accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]); + } + + for (int j = 0; j < QK_K/128; ++j) { + deq.prepare(i, j); + auto& values = deq.bits.values; + for (int iy = 0; iy < nrc_y; ++iy) { + auto qs = q8.y[iy][2*i+j].qs; +#ifdef HAVE_FANCY_SIMD + sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0)); + sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1)); + sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2)); + sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3)); +#else + sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0))); + sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1))); + sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2))); + sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3))); +#endif + sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1])); + sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3])); + sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2])); + auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16)); + auto dy = _mm256_set_m128(d4, d4); + accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + +template <int nrc_y> +static void mul_mat_iq2_xxs_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_K; + __m256 scales[2]; + __m256 accd[nrc_y]; + __m256i sumi[4]; + __m256i xv[4]; + EvenSignHelper esh; + + for (int ix = 0; ix < nrc_x; ++ix) { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + const block_iq2_xxs * x = (const block_iq2_xxs *)((const char *)vx + ix*bx); + + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d)*0.125f; + const uint16_t * a16 = x[i].qs; + auto sc16 = _mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]); + sc16 = _mm_or_si128(_mm_slli_epi16(_mm_srli_epi16(sc16, 12), 1), _mm_set1_epi16(1)); + auto sc32 = _mm256_cvtepi16_epi32(sc16); + auto all_scales = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sc32)); + auto all_mins = _mm256_mul_ps(all_scales, _mm256_set1_ps(-43.f)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = (const block_q8_2_x4 *)info.src1_row(iy); + auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+0].d + 4))); + auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+1].d + 4))); + auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16)); + accd[iy] = _mm256_fmadd_ps(all_mins, my, accd[iy]); + } + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + + for (int j = 0; j < QK_K/128; ++j) { + const uint8_t * a8 = (const uint8_t *)(a16 + 16*j); + for (int k = 0; k < 4; ++k) { + auto a8k = a8 + 8*k; + xv[k] = _mm256_set_epi64x(iq2xxs_grid[a8k[3]], iq2xxs_grid[a8k[2]], iq2xxs_grid[a8k[1]], iq2xxs_grid[a8k[0]]); + uint32_t aux32; std::memcpy(&aux32, a8k+4, sizeof(uint32_t)); + esh.sign_value(aux32, xv[k]); + xv[k] = _mm256_add_epi8(xv[k], _mm256_set1_epi8(43)); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = (const block_q8_2_x4 *)info.src1_row(iy); + sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[0], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+0)); + sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[1], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+1)); + sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[2], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+2)); + sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[3], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+3)); + sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1])); + sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3])); + sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2])); + auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)y[2*i+j].d)), 16)); + auto dy = _mm256_set_m128(d4, d4); + accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + +template <typename Dequantizer, int nrc_y> static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); #ifdef HAVE_FANCY_SIMD @@ -1560,6 +1722,55 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI } } +void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq2_xxs * x8[8]; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + ggml_half dh[8]; + uint16_t all_ls[64]; + EvenSignHelper esh; + + uint32_t block[8]; + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + // TODO: simdify + for (int k = 0; k < 8; ++k) { + dh[k] = x8[k][i].d; + for (int ib32 = 0; ib32 < 8; ++ib32) { + std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t)); + all_ls[8*ib32 + k] = (2*(aux32[1] >> 28) + 1); + auto value = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + esh.sign_value(aux32[1], value); + _mm256_storeu_si256((__m256i *)block, value); + auto qs = (uint32_t *)y[ib32].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh))); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32); + auto iscales32 = _mm256_cvtepi16_epi32(iscales16); + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32)); + _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) { funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>; funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>; @@ -1575,7 +1786,19 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { - if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) { + if (ne00%QK_K != 0) return false; + + if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) { + if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) { + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels); + //IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels); + func16 = nullptr; + return true; + } + return false; + } + + if (ggml_type(typeB) != GGML_TYPE_Q8_K) { return false; } @@ -1629,6 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_ } +bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) { + if (n%QK_K != 0 || nrc_x%8 != 0) return false; + switch (ggml_type(type)) { + case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break; + default: return false; + } + return true; +} + #else // --------------------------------------- __aarch64__ --------------------------------------------- |