diff options
| -rw-r--r-- | iqk_mul_mat.cpp | 142 | 
1 files changed, 86 insertions, 56 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index bb41a33c..03641677 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1400,6 +1400,61 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {  }; +struct EvenSignHelper { +#ifdef _HAVE_FANCY_SIMD +    IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const { +        aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask); +        //auto aux1 = _mm256_xor_si256(aux, _mm256_and_si256(_mm256_srli_epi16(aux), _mm256_set1_epi8(0xf))); +        //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_shuffle_epi8(bhelper, aux1))); +        auto pcnt = _mm256_popcnt_epi32(aux); +        auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); +        const __mmask32 * m32 = (const __mmask32 *)&sign_bits; +        values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]); +        values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]); +    } +    IQK_ALWAYS_INLINE void sign_2_values(const uint32_t * aux32, __m256i * values) const { +        sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[2]), _mm_set1_epi32(aux32[0])), values); +    } +    IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * aux16, __m256i * values) const { +        sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux16[2] | (aux16[3] << 16)), _mm_set1_epi32(aux16[0] | (aux16[1] << 16))), values); +    } +#else +    IQK_ALWAYS_INLINE void sign_value(uint32_t aux32, __m256i& value) const { +        auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], +                                       keven_signs[(aux32 >>  7) & 127], keven_signs[(aux32 >>  0) & 127]); +        value = _mm256_sign_epi8(value, signs); +    } +    IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * aux16, __m256i * values) const { +        sign_value(aux16[0] | (aux16[1] << 16), values[0]); +        sign_value(aux16[2] | (aux16[3] << 16), values[1]); +    } +#endif +    inline void sign_values(const uint32_t * aux32, __m256i * values) const { +#ifdef _HAVE_FANCY_SIMD +        sign_2_values(aux32+1, values+0); +        sign_2_values(aux32+5, values+2); +#else +        sign_value(aux32[1], values[0]); +        sign_value(aux32[3], values[1]); +        sign_value(aux32[5], values[2]); +        sign_value(aux32[7], values[3]); +#endif +    } +#ifdef _HAVE_FANCY_SIMD +    const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0); +    const __m256i mask   = _mm256_set1_epi32(127); +    const __m256i mone   = _mm256_set1_epi32(1); +    //const __m256i bhelper = load_bhelper(); +    //static __m256i load_bhelper() { +    //    static const uint8_t k_bit_helper[32] = { +    //        0x0, 0x8, 0x8, 0x0, 0x8, 0x0, 0x0, 0x8, 0x8, 0x0, 0x0, 0x8, 0x0, 0x8, 0x8, 0x0, +    //        0x0, 0x8, 0x8, 0x0, 0x8, 0x0, 0x0, 0x8, 0x8, 0x0, 0x0, 0x8, 0x0, 0x8, 0x8, 0x0, +    //    }; +    //    return _mm256_loadu_si256((const __m256i*)k_bit_helper); +    //} +#endif +}; +  struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {      DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} @@ -1428,43 +1483,50 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {          return _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],                                  iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);      } -    inline static __m256i make_signs(const uint16_t * sidx) { -        uint32_t aux32 = sidx[0] | (sidx[1] << 16); -        return _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], -                                 keven_signs[(aux32 >>  7) & 127], keven_signs[aux32 & 127]); -    } -    inline static __m256i make1(const uint8_t * qs, const uint16_t * sidx, __m256i& q8_quants) { -        q8_quants = _mm256_sign_epi8(q8_quants, make_signs(sidx)); -        return make_quants(qs); -    } -    inline static __m256i make1(const uint8_t * qs, const uint16_t * sidx, const __m256i& min_value) { -        auto val = make_quants(qs); -        auto s   = make_signs(sidx); -        return _mm256_add_epi8(_mm256_sign_epi8(val, s), min_value); +    inline static void make4_unsigned(const uint8_t * qs, __m256i * values) { +        values[0] = make_quants(qs+ 0); +        values[1] = make_quants(qs+ 8); +        values[2] = make_quants(qs+16); +        values[3] = make_quants(qs+24);      } +    //inline static __m256i make_signs(const uint16_t * sidx) { +    //    uint32_t aux32 = sidx[0] | (sidx[1] << 16); +    //    return _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], +    //                             keven_signs[(aux32 >>  7) & 127], keven_signs[aux32 & 127]); +    //} +    //inline static __m256i make1(const uint8_t * qs, const uint16_t * sidx, __m256i& q8_quants) { +    //    q8_quants = _mm256_sign_epi8(q8_quants, make_signs(sidx)); +    //    return make_quants(qs); +    //} +    //inline static __m256i make1(const uint8_t * qs, const uint16_t * sidx, const __m256i& min_value) { +    //    auto val = make_quants(qs); +    //    auto s   = make_signs(sidx); +    //    return _mm256_add_epi8(_mm256_sign_epi8(val, s), min_value); +    //}      inline void prepare(int i, int j) {          auto qs = x[i].qs + 32*j;          const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j; -        bits.values[0] = make1(qs+ 0, signs+0, min_value); -        bits.values[1] = make1(qs+ 8, signs+2, min_value); -        bits.values[2] = make1(qs+16, signs+4, min_value); -        bits.values[3] = make1(qs+24, signs+6, min_value); +        make4_unsigned(qs, bits.values); +        esh.sign_2_values(signs+0, bits.values+0); +        esh.sign_2_values(signs+4, bits.values+2); +        for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi32(bits.values[k], min_value);      }      template <typename Q8>      inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) { +        for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);          auto qs = x[i].qs + 32*j;          const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j; -        q8_quants[0] = q8.load_quants(0, i, 4*j+0); bits.values[0] = make1(qs+ 0, signs+0, q8_quants[0]); -        q8_quants[1] = q8.load_quants(0, i, 4*j+1); bits.values[1] = make1(qs+ 8, signs+2, q8_quants[1]); -        q8_quants[2] = q8.load_quants(0, i, 4*j+2); bits.values[2] = make1(qs+16, signs+4, q8_quants[2]); -        q8_quants[3] = q8.load_quants(0, i, 4*j+3); bits.values[3] = make1(qs+24, signs+6, q8_quants[3]); +        make4_unsigned(qs, bits.values); +        esh.sign_2_values(signs+0, q8_quants+0); +        esh.sign_2_values(signs+4, q8_quants+2);      }      constexpr static int minv = 64;      SimpleBits bits;      Scales8KBase scb; +    EvenSignHelper esh;      const __m256i min_value = _mm256_set1_epi8(minv);  }; @@ -1723,43 +1785,15 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {          values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);          values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);      } -#ifdef HAVE_FANCY_SIMD -    inline void sign_2_values(const uint32_t * aux32, __m256i * values) const { -        auto aux = MM256_SET_M128I(_mm_set1_epi32(aux32[2]), _mm_set1_epi32(aux32[0])); -        aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask); -        auto pcnt = _mm256_popcnt_epi32(aux); -        auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); -        const __mmask32 * m32 = (const __mmask32 *)&sign_bits; -        values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]); -        values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]); -    } -#else -    inline void sign_value(uint32_t aux32, __m256i& value) const { -        auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], -                                       keven_signs[(aux32 >>  7) & 127], keven_signs[(aux32 >>  0) & 127]); -        value = _mm256_sign_epi8(value, signs); -    } -#endif -    inline void sign_values(const uint32_t * aux32, __m256i * values) const { -#ifdef HAVE_FANCY_SIMD -        sign_2_values(aux32+1, values+0); -        sign_2_values(aux32+5, values+2); -#else -        sign_value(data.val[1], values[0]); -        sign_value(data.val[3], values[1]); -        sign_value(data.val[5], values[2]); -        sign_value(data.val[7], values[3]); -#endif -    }      inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const {          make4(aux32, values); -        sign_values(aux32, values); +        esh.sign_values(aux32, values);          for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);      }      inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const {          make4(aux32, values); -        sign_values(aux32, q8); +        esh.sign_values(aux32, q8);      }      inline void prepare(int i, int j) {          Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); @@ -1775,13 +1809,9 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {      constexpr static int minv = 43;      SimpleBits bits;      Scales8KBase scb; +    EvenSignHelper esh;      const __m256i min_value = _mm256_set1_epi8(minv);      const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1); -#ifdef HAVE_FANCY_SIMD -    const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0); -    const __m256i mask   = _mm256_set1_epi32(127); -    const __m256i mone   = _mm256_set1_epi32(1); -#endif  };  //  | 
