summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--iqk_mul_mat.cpp50
1 files changed, 48 insertions, 2 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index bd5d1430..54a2d567 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -2185,6 +2185,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
b[1] = make1(qs + 2);
b[2] = make1(qs + 4);
b[3] = make1(qs + 6);
+ // The following is actually slower
//auto bits = vld1q_u16(qs);
//auto vidx = vandq_u16(bits, vdupq_n_u16(511));
//const uint16_t * idx = (const uint16_t *)&vidx;
@@ -2275,11 +2276,53 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
SimpleBits bits;
- constexpr static const uint64x2_t scale_shuffle = { 0x0b030a0209010800, 0x0f070e060d050c04 };
+ float d;
+
+};
+
+struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
+ DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
+ gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4));
+ int32x4x2_t scales;
+ scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[0], 28), 1), vdupq_n_u32(1)));
+ scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[1], 28), 1), vdupq_n_u32(1)));
+ return scales;
+ }
+
+ inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) {
+ const uint64_t * signs = keven_signs;
+ b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});
+ b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});
+ auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
+ auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 21) & 127))));
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
+ b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
+ }
+ inline void prepare(int i, int j) {
+ const auto * q3 = x[i].qs + 32*j;
+ const auto * signs = (const uint32_t *)(gas.val + j);
+ for (int k = 0; k < 2; ++k) {
+ make2(q3, signs[k], bits.b1.val + 2*k); q3 += 8;
+ }
+ signs += 2;
+ for (int k = 0; k < 2; ++k) {
+ make2(q3, signs[k], bits.b2.val + 2*k); q3 += 8;
+ }
+ }
+
+ SimpleBits bits;
+ uint32x4x2_t gas;
float d;
- };
+};
template <int nrc_y, typename Dequantizer>
@@ -2835,6 +2878,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
case GGML_TYPE_IQ2_S:
MulMat::set_functions<DequantizerIQ2S>(m);
break;
+ case GGML_TYPE_IQ3_XXS:
+ MulMat::set_functions<DequantizerIQ3XXS>(m);
+ break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);