diff options
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 70 |
1 files changed, 63 insertions, 7 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 02458ac5..3a81d3ac 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -825,13 +825,6 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { auto val256 = MM256_SET_M128I(val128, val128); return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); } - //inline __m128i make_scales(__mmask16 signs, const uint8_t * scales_l) const { - // uint64_t aux64; std::memcpy(&aux64, scales_l, 8); - // auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); - // scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1); - // const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(x[i].scales_h), sign_mask), sign_mask); - // return _mm_mask_sub_epi8(scl, signs, _mm_setzero_si128(), scl); - //} inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const { uint64_t aux64; std::memcpy(&aux64, scales_l, 8); auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); @@ -3905,6 +3898,66 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> { float d; }; +struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { + DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc); + } + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + if (j == 0) { + hbits = vld1q_u8_x2(x[i].qh); + } + else { + hbits.val[0] = vshrq_n_u8(hbits.val[0], 4); + hbits.val[1] = vshrq_n_u8(hbits.val[1], 4); + } + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask)); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask)); + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]); + bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]); + } + } + inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const { + uint8x8_t aux = vld1_u8(scales_l); + uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf)); + int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1)); + uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask); + signs = vorrq_u8(signs, vdupq_n_u8(1)); + // scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 + // signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15 + scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle))); + return vqtbl1q_s8(scales, hshuff); + } + inline static uint8x16_t load_sign_shuffle() { + static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + return vld1q_u8(k_shuff); + } + + Q2bits bits; + uint8x16x2_t hbits; + const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1)); + const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06}); + const uint8x16_t hmask = vdupq_n_u8(4); + const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010}); + const uint8x16_t sign_shuffle = load_sign_shuffle(); + + float d; +}; + struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> { static int8x16_t load_values() { @@ -5240,6 +5293,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_IQ2_K: MulMat::set_functions<DequantizerIQ2K>(m); break; + case GGML_TYPE_IQ3_K: + MulMat::set_functions<DequantizerIQ3K>(m); + break; case GGML_TYPE_IQ2_XXS: MulMat::set_functions<DequantizerIQ2XXS>(m); break; |