Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 317
1 file changed, 303 insertions, 14 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3a81d3ac..db83b841 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -692,6 +692,16 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
 };
 
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+    DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+    template <typename Q8>
+    inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
+        d = GGML_FP16_TO_FP32(x[i].d);
+        bits.prepare(x[i].qs);
+    }
+    Q2Bits bits;
+};
+
 struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
     DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
     template <typename Q8>
@@ -960,6 +970,16 @@ inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i *
     accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
 }
 
+template <typename Q8>
+inline void compute_block_iq2tn(int iy, int i, float d, const Q8& q8, const __m512i * values, __m512 * accd) {
+    auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
+    auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
+                    _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
+                    values[0], q8.load_quants64(iy, i, 0)), values[1], q8.load_quants64(iy, i, 1)),
+                    values[2], q8.load_quants64(iy, i, 2)), values[3], q8.load_quants64(iy, i, 3));
+    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+}
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
@@ -985,14 +1005,22 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
             deq.new_block(i, q8, accm, scales);
 
             for (int iy = 0; iy < nrc_y; ++iy) {
-                //compute_block(iy, i, deq.d, q8, deq.bits.values, scales, accd);
-                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
-                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
-                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
-                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
-                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
-                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
-                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+                if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+                    auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
+                    auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
+                                    _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
+                                    deq.bits.values[0], q8.load_quants64(iy, i, 0)), deq.bits.values[1], q8.load_quants64(iy, i, 1)),
+                                    deq.bits.values[2], q8.load_quants64(iy, i, 2)), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+                    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+                } else {
+                    const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
+                    const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
+                    const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
+                    const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+                    auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+                    sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+                    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+                }
             }
         }
 
@@ -1034,19 +1062,33 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
 
             for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
 
-            for (int kx = 0; kx < k_nx; ++kx) {
-                compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+            if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+                for (int kx = 0; kx < k_nx; ++kx) {
+                    compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd);
+                }
+            } else {
+                for (int kx = 0; kx < k_nx; ++kx) {
+                    compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+                }
             }
         }
 
         if (2*(nb/2) < nb) {
            int i0 = 2*(nb/2);
            deq[0]->new_block(i0, q8, &accm, scales);
-           compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+           if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+               compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd);
+           } else {
+               compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+           }
         }
 
-        auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
-        info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+        if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+            info.store(ix, 0, _mm512_reduce_add_ps(accd));
+        } else {
+            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+            info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+        }
     }
 }
 
@@ -1439,6 +1481,74 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
     const __m256i mh = _mm256_set1_epi8(0x30);
 };
 
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+    DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+    inline void new_block(int i) {
+        d = GGML_FP16_TO_FP32(x[i].d);
+    }
+    inline void prepare(int i, int j) {
+        bits.prepare(x[i].qs, j);
+    }
+
+    Q2Bits bits;
+};
+
+
+template <int nrc_y>
+IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Q8<nrc_y> q8(info);
+    DequantizerIQ2TN deq(vx, bx);
+
+    __m256 accd[nrc_y];
+    const auto m1 = _mm256_set1_epi16(1);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        deq.new_row(ix);
+
+        for (int i = 0; i < nb; ++i) {
+
+            __m256i sumi[nrc_y];
+            deq.new_block(i);
+
+            deq.prepare(i, 0);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)),
+                                            _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1)));
+                sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)),
+                                            _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 3))), sumi[iy]);
+            }
+            deq.prepare(i, 1);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 4)),
+                                            _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 5))), sumi[iy]);
+                sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 6)),
+                                            _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]);
+                sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i));
+            }
+            if (i > 0) {
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]);
+                }
+            } else {
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])));
+                }
+            }
+
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(accd[iy]));
+        }
+
+    }
+}
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
@@ -1931,7 +2041,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
                         _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
                         _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
-            accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+            accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot;
 #endif
         }
     }
@@ -3156,6 +3266,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             assert (ne00 % QK_K == 0);
             MulMat::set_functions<DequantizerQ2K>(mm);
             break;
+        case GGML_TYPE_IQ2_TN:
+            assert (ne00 % QK_K == 0);
+#ifdef HAVE_FANCY_SIMD
+            MulMat::set_functions<DequantizerIQ2TN>(mm);
+#else
+            mm.funcs[0] = mul_mat_iq2tn_q8_K<1>;
+            mm.funcs[1] = mul_mat_iq2tn_q8_K<2>;
+            mm.funcs[2] = mul_mat_iq2tn_q8_K<3>;
+            mm.funcs[3] = mul_mat_iq2tn_q8_K<4>;
+            mm.funcs[4] = mul_mat_iq2tn_q8_K<5>;
+            //mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
+            //mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
+            //mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
+#endif
+            break;
         case GGML_TYPE_Q3_K:
             assert (ne00 % QK_K == 0);
             MulMat::set_functions<DequantizerQ3K>(mm);
@@ -4280,6 +4405,159 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
 
 };
 
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+    DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+    constexpr static int num_blocks() { return 16; }
+    constexpr static bool should_scale_quants() { return true; }
+
+    //template <typename Q8>
+    //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
+    //    d = GGML_FP16_TO_FP32(x[i].d);
+    //}
+
+    inline void new_block(int i) {
+        d = GGML_FP16_TO_FP32(x[i].d);
+    }
+
+    template <typename Q8>
+    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
+        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
+                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
+
+            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
+                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
+
+            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
+                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
+
+            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
+                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
+        }
+    }
+    template <typename Q8>
+    inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) {
+        auto q8b_1 = q8.load_quants(0, i, 4*j+0);
+        sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
+                vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
+
+        auto q8b_2 = q8.load_quants(0, i, 4*j+1);
+        sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
+                vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
+
+        q8b_1 = q8.load_quants(0, i, 4*j+2);
+        sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]),
+                vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]);
+
+        q8b_2 = q8.load_quants(0, i, 4*j+3);
+        sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]),
+                vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]);
+    }
+
+    IQK_ALWAYS_INLINE void prepare(int i, int j) {
+        bits.prepare(x[i].qs+32*j);
+        auto m1 = vdupq_n_s8(1);
+        for (int k = 0; k < 4; ++k) {
+            bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1);
+            bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1);
+        }
+    }
+
+    Q2bits bits;
+
+    float d;
+};
+
+template <int nrc_y>
+void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n % QK_K == 0);
+    const int nb = n / QK_K;
+
+    Q8<nrc_y, block_q8_K> q8(info);
+
+    DequantizerIQ2TN deq(vx, bx, nrc_y);
+    float32x4_t acc[nrc_y];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        deq.new_row(ix);
+
+        for (int i = 0; i < nb; ++i) {
+
+            int32x4_t sumi[nrc_y];
+            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
+
+            deq.new_block(i);
+            deq.prepare(i, 0);
+            deq.compute(q8, i, 0, sumi);
+            deq.prepare(i, 1);
+            deq.compute(q8, i, 1, sumi);
+
+            if (i > 0) {
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
+                }
+            } else {
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    acc[iy] = vmulq_f32(vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, vaddvq_f32(acc[iy]));
+        }
+    }
+}
+void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n % QK_K == 0);
+    const int nb = n / QK_K;
+
+    Q8<1, block_q8_K> q8(info);
+
+    DequantizerIQ2TN deq(vx, bx, 1);
+
+    auto m1 = vdup_n_s16(-1);
+    float32x4_t acc[2];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        deq.new_row(ix);
+
+        for (int i = 0; i < nb; ++i) {
+
+            int32x4_t sumi[2] = {};
+            deq.new_block(i);
+            auto bsums = q8.load_bsums(0, i);
+            bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]);
+            sumi[0] = vmlal_s16(sumi[0], vget_low_s16 (bsums.val[0]), m1);
+            sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1);
+            deq.bits.prepare(deq.x[i].qs);
+            deq.compute1(q8, i, 0, sumi);
+            deq.bits.prepare(deq.x[i].qs+32);
+            deq.compute1(q8, i, 1, sumi);
+
+            auto vd = vdupq_n_f32(deq.d*q8.scale(0, i));
+            if (i > 0) {
+                acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd);
+                acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd);
+            } else {
+                acc[0] = vmulq_f32(vcvtq_f32_s32(sumi[0]), vd);
+                acc[1] = vmulq_f32(vcvtq_f32_s32(sumi[1]), vd);
+            }
+
+        }
+
+        acc[0] = vaddq_f32(acc[0], acc[1]);
+        info.store(ix, 0, vaddvq_f32(acc[0]));
+    }
+}
+
 template <int nrc_y, typename Dequantizer>
 void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
@@ -5269,6 +5547,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_Q2_K:
             MulMat::set_functions<DequantizerQ2K>(m);
             break;
+        case GGML_TYPE_IQ2_TN:
+            //MulMat::set_functions<DequantizerIQ2TN>(m);
+            m.funcs[0] = mul_mat_iq2tn_K_q8_K_1;
+            m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>;
+            m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>;
+            m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>;
+            m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>;
+            m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>;
+            m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>;
+            m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>;
+            break;
        case GGML_TYPE_Q3_K:
            MulMat::set_functions<DequantizerQ3K>(m);
            break;
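
All three kernels added by this patch (AVX512, AVX2, NEON) compute the same thing, and the bias trick they share is easiest to see in scalar form: IQ2_TN stores ternary weights {-1, 0, 1} biased into the unsigned range {0, 1, 2}, so the unsigned-by-signed dot instructions (_mm512_dpbusd_epi32, _mm256_maddubs_epi16) can consume the quants directly, and the bias is removed afterwards with the per-block activation sums that q8.load_bsums() provides. A minimal scalar sketch of that identity, assuming the quants are already unpacked to bytes (the helper name and the qu array are illustrative, not part of the patch):

#include <cstdint>

// Reference for the bias correction used by the AVX512/AVX2 paths above.
// qu[k] in {0, 1, 2} is the stored 2-bit quant, i.e. the ternary weight + 1;
// y[k] is a q8_K activation value.
static int32_t iq2tn_block_dot_ref(const uint8_t * qu, const int8_t * y, int n) {
    int32_t unsigned_dot = 0; // what dpbusd/maddubs accumulate
    int32_t ysum         = 0; // what q8.load_bsums() supplies per block
    for (int k = 0; k < n; ++k) {
        unsigned_dot += qu[k] * y[k];
        ysum         += y[k];
    }
    // sum((qu[k] - 1) * y[k]) == sum(qu[k] * y[k]) - sum(y[k])
    return unsigned_dot - ysum;
}

The three implementations differ only in where they apply the -sum(y) term: the AVX512 path seeds its i32 accumulator with _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i)); the AVX2 path subtracts the bsums from the i16 accumulator after the maddubs chain; and the NEON multi-row kernel avoids the correction altogether by subtracting 1 from the unpacked quants in prepare() and using signed ggml_vdotq_s32 directly, while its single-row variant mul_mat_iq2tn_K_q8_K_1 keeps the bsums trick so it can skip the eight vsubq_s8 per super-block.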
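It is also worth spelling out why the AVX2 kernel may accumulate a whole 256-weight super-block in 16-bit lanes and only widen once, via _mm256_madd_epi16(m1, sumi[iy]): with quants bounded by 2, the partial sums cannot overflow int16. A back-of-the-envelope check (the constant names are mine, not from the patch):

#include <cstdint>

// Each i16 lane of sumi[iy] accumulates 8 _mm256_maddubs_epi16 results
// (values[0..3] for the j = 0 half of the super-block, then again for
// j = 1), and then has one q8_K block sum subtracted.
constexpr int kMaxPerMaddubs = 2 * (2 * 127);      // 2 u8*s8 products per lane, u8 <= 2
constexpr int kMaxAccum      = 8 * kMaxPerMaddubs; // 8 maddubs per super-block = 4064
constexpr int kMaxBsum       = 16 * 127;           // one bsum covers 16 activations = 2032
static_assert(kMaxAccum + kMaxBsum <= INT16_MAX, "super-block fits in i16 lanes");

This headroom is why the kernel is correct without any intermediate widening or per-sub-block scale handling of the kind the q2_K path needs; since IQ2_TN carries a single scale per super-block, new_block() reduces to loading d.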