summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp70
1 files changed, 63 insertions, 7 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 02458ac5..3a81d3ac 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -825,13 +825,6 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
auto val256 = MM256_SET_M128I(val128, val128);
return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
}
- //inline __m128i make_scales(__mmask16 signs, const uint8_t * scales_l) const {
- // uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
- // auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
- // scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1);
- // const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(x[i].scales_h), sign_mask), sign_mask);
- // return _mm_mask_sub_epi8(scl, signs, _mm_setzero_si128(), scl);
- //}
inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const {
uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
@@ -3905,6 +3898,66 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
float d;
};
+struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
+ DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ if (j == 0) {
+ hbits = vld1q_u8_x2(x[i].qh);
+ }
+ else {
+ hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
+ hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
+ }
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask));
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask));
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask));
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask));
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
+ bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
+ }
+ }
+ inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const {
+ uint8x8_t aux = vld1_u8(scales_l);
+ uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
+ int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1));
+ uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask);
+ signs = vorrq_u8(signs, vdupq_n_u8(1));
+ // scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15
+ // signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
+ scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle)));
+ return vqtbl1q_s8(scales, hshuff);
+ }
+ inline static uint8x16_t load_sign_shuffle() {
+ static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
+ return vld1q_u8(k_shuff);
+ }
+
+ Q2bits bits;
+ uint8x16x2_t hbits;
+ const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1));
+ const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
+ const uint8x16_t hmask = vdupq_n_u8(4);
+ const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010});
+ const uint8x16_t sign_shuffle = load_sign_shuffle();
+
+ float d;
+};
+
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
static int8x16_t load_values() {
@@ -5240,6 +5293,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ2_K:
MulMat::set_functions<DequantizerIQ2K>(m);
break;
+ case GGML_TYPE_IQ3_K:
+ MulMat::set_functions<DequantizerIQ3K>(m);
+ break;
case GGML_TYPE_IQ2_XXS:
MulMat::set_functions<DequantizerIQ2XXS>(m);
break;