diff options
Diffstat (limited to 'ggml/src')
| -rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 57 | 
1 files changed, 57 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 364d9872..0653b654 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -798,6 +798,59 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
     const __m128i m15 = _mm_set1_epi8(-15);
 };
 
+struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> { // block_iq3_k variant operating on 512-bit vectors
+    DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(4, -64)), values(load_values()) {}
+    template <typename Q8>
+    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { // per-block setup for block index i
+        d = GGML_FP16_TO_FP32(x[i].d); // block scale, converted from its stored fp16 form
+        prepare(x[i].qs, x[i].qh); // fills bits.values[0..3] for this block
+        iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales);
+    }
+    inline void prepare(const uint8_t * q2, const uint8_t * qh) { // q2: low quant bits (unpacked by Q2Bits), qh: 32 bytes of extra bits
+        bits.prepare(q2);
+        auto h256 = _mm256_loadu_si256((const __m256i *)qh); // unaligned 32-byte load
+        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1); // low half: qh, high half: qh >> 1
+        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), hmask)); // qh bit 0 (low half) / bit 1 (high half) -> bit 2
+        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, hmask)); // qh bit 2 / bit 3 -> bit 2
+        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), hmask)); // qh bit 4 / bit 5 -> bit 2
+        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), hmask)); // qh bit 6 / bit 7 -> bit 2
+        bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); // byte-wise lookup: each index byte selects an entry of `values`
+        bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
+        bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
+        bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
+    }
+    static inline __m512i load_values() { // builds the 512-bit lookup table used by prepare()
+        static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115};
+        auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl);
+        auto val256 = MM256_SET_M128I(val128, val128); // NOTE(review): presumably packs val128 into both 128-bit halves - confirm macro definition
+        return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); // same 256 bits in the low and high halves
+    }
+    //inline __m128i make_scales(__mmask16 signs, const uint8_t * scales_l) const {
+    //    uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
+    //    auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
+    //    scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1);
+    //    const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(x[i].scales_h), sign_mask), sign_mask);
+    //    return _mm_mask_sub_epi8(scl, signs, _mm_setzero_si128(), scl);
+    //}
+    inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const { // returns 16 signed 8-bit scales
+        uint64_t aux64; std::memcpy(&aux64, scales_l, 8); // 8 bytes = 16 packed 4-bit values
+        auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf)); // low nibbles in bytes 0..7, high nibbles in bytes 8..15
+        scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1); // 2*nibble + 1 (at most 31, so the 16-bit shift cannot spill across bytes)
+        const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask); // 0xFF where the tested bit of signs is set
+        const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff); // bytes become -1 (bit set) or +1, reordered by hshuff
+        return _mm_sign_epi8(scl, sch); // negates the entries whose sch byte is -1
+    }
+    Q2Bits bits;
+    const IQXKScales iqxk;
+
+    const __m512i values; // lookup table produced by load_values()
+    const __m512i hmask = _mm512_set1_epi8(4); // selects bit 2 of every byte
+    const __m128i m1 = _mm_set1_epi8(1);
+    const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101); // one single-bit mask per byte pair: 0x01, 0x02, ..., 0x80
+    const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
+    constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; // byte permutation applied to the sign bytes
+};
+
 struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
     DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_512()) {}
     template <typename Q8>
@@ -3088,6 +3141,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             assert (ne00 % QK_K == 0);
             MulMat::set_functions<DequantizerIQ2K>(mm);
             break;
+        case GGML_TYPE_IQ3_K: // route IQ3_K tensors to the new dequantizer
+            assert (ne00 % QK_K == 0);
+            MulMat::set_functions<DequantizerIQ3K>(mm);
+            break;
         case GGML_TYPE_IQ4_K:
             assert (ne00 % QK_K == 0);
             MulMat::set_functions<DequantizerIQ4K>(mm);
 | 
