author    Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-05-28 10:58:30 +0200
committer Iwan Kawrakow <iwan.kawrakow@gmail.com>  2024-06-22 12:02:49 +0300
commit    221a2c38070040c679c56a7d4c598508d485a759 (patch)
tree      13882c64630379fb8ae66fc7616c3dd7301bac2d
parent    7dcca6aea746de02694ccd0d5fbeae73452533d1 (diff)
Simplify

Factor the scale and sign handling that was duplicated between the
IQ2_XXS and IQ3_XXS dequantizers into shared prepare_scales_8() and
apply_signs_2() helpers.
-rw-r--r--  iqk_mul_mat.cpp | 71
1 file changed, 31 insertions(+), 40 deletions(-)
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 54a2d567..08f2bd47 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -2088,6 +2088,20 @@ const uint64_t kall_signs[256] = {
0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,
};
+inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {
+ int32x4x2_t scales;
+ scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1)));
+ scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1)));
+ return scales;
+}
+
+inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {
+ auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
+ auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
+ b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
+}
+
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
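The two helpers added above centralize logic that until now was duplicated in the IQ2_XXS and IQ3_XXS dequantizers (see the hunks below): prepare_scales_8() decodes the 4-bit scale stored in bits 28..31 of each packed word as 2*s+1, and apply_signs_2() multiplies 32 quants by sign masks looked up through four 7-bit indices packed at bit offsets 0, 7, 14 and 21. (The helper also makes the u8/s8 vreinterpretq casts explicit, which the removed unrolled IQ2_XXS lines below omitted.) A plain-C++ sketch of the same computation, not part of the patch, assuming each keven_signs entry is eight ±1 bytes in the style of kall_signs above:

inline int32_t scale_of(uint32_t v) {            // per-lane body of prepare_scales_8
    return 2 * (v >> 28) + 1;                    // == ((v >> 28) << 1) | 1
}
inline void apply_signs_scalar(int8_t * b, const uint64_t * signs, uint32_t sidx) {
    for (int k = 0; k < 4; ++k) {                // 32 bytes = the two NEON registers
        const int8_t * s = (const int8_t *)(signs + ((sidx >> 7*k) & 127));
        for (int l = 0; l < 8; ++l) b[8*k + l] *= s[l];   // apply ±1 sign mask
    }
}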
@@ -2104,35 +2118,22 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7
data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7
- int32x4x2_t scales;
- scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[1], 28), 1), vdupq_n_u32(1)));
- scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[3], 28), 1), vdupq_n_u32(1)));
- return scales;
+ return prepare_scales_8(data.val[1], data.val[3]);
+ }
+
+ static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) {
+ b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});
+ b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});
+ apply_signs_2(b, signs, sidx);
}
inline void prepare(int /*i*/, int j) {
const uint8_t * idx = (const uint8_t *)(data.val + 2*j);
- bits.b1.val[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[ 0]], iq2xxs_grid[idx[ 1]]});
- bits.b1.val[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[ 2]], iq2xxs_grid[idx[ 3]]});
- bits.b1.val[2] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[ 4]], iq2xxs_grid[idx[ 5]]});
- bits.b1.val[3] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[ 6]], iq2xxs_grid[idx[ 7]]});
- bits.b2.val[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[ 8]], iq2xxs_grid[idx[ 9]]});
- bits.b2.val[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[10]], iq2xxs_grid[idx[11]]});
- bits.b2.val[2] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[12]], iq2xxs_grid[idx[13]]});
- bits.b2.val[3] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[14]], iq2xxs_grid[idx[15]]});
-
- //const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
- const uint64_t * signs = keven_signs;
const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1);
-
- bits.b1.val[0] = vmulq_s8(bits.b1.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[0] >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[0] >> 7) & 127)))));
- bits.b1.val[1] = vmulq_s8(bits.b1.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[0] >> 14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[0] >> 21) & 127)))));
- bits.b1.val[2] = vmulq_s8(bits.b1.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[1] >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[1] >> 7) & 127)))));
- bits.b1.val[3] = vmulq_s8(bits.b1.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[1] >> 14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[1] >> 21) & 127)))));
- bits.b2.val[0] = vmulq_s8(bits.b2.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[2] >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[2] >> 7) & 127)))));
- bits.b2.val[1] = vmulq_s8(bits.b2.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[2] >> 14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[2] >> 21) & 127)))));
- bits.b2.val[2] = vmulq_s8(bits.b2.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[3] >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[3] >> 7) & 127)))));
- bits.b2.val[3] = vmulq_s8(bits.b2.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx[3] >> 14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx[3] >> 21) & 127)))));
+ prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4;
+ prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4;
+ prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4;
+ prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]);
}
uint32x4x4_t data;
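Each prepare2() call above consumes one 32-quant block: four codebook bytes, each selecting an 8-quant row of iq2xxs_grid, plus one packed word carrying the block's four sign indices and its scale, so idx and sidx advance in lockstep through data.val[2*j] and data.val[2*j + 1]. Purely for illustration (the patch keeps the unrolled form), the new body of prepare() is equivalent to:

    uint8x16_t * dst[4] = { bits.b1.val, bits.b1.val + 2, bits.b2.val, bits.b2.val + 2 };
    for (int k = 0; k < 4; ++k) prepare2(dst[k], idx + 4*k, keven_signs, sidx[k]);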
@@ -2290,31 +2291,21 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4));
- int32x4x2_t scales;
- scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[0], 28), 1), vdupq_n_u32(1)));
- scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[1], 28), 1), vdupq_n_u32(1)));
- return scales;
+ return prepare_scales_8(gas.val[0], gas.val[1]);
}
inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) {
- const uint64_t * signs = keven_signs;
b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});
b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});
- auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
- auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 21) & 127))));
- b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
- b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
+ apply_signs_2(b, keven_signs, sidx);
}
inline void prepare(int i, int j) {
const auto * q3 = x[i].qs + 32*j;
const auto * signs = (const uint32_t *)(gas.val + j);
- for (int k = 0; k < 2; ++k) {
- make2(q3, signs[k], bits.b1.val + 2*k); q3 += 8;
- }
- signs += 2;
- for (int k = 0; k < 2; ++k) {
- make2(q3, signs[k], bits.b2.val + 2*k); q3 += 8;
- }
+ make2(q3, signs[0], bits.b1.val + 0); q3 += 8;
+ make2(q3, signs[1], bits.b1.val + 2); q3 += 8;
+ make2(q3, signs[2], bits.b2.val + 0); q3 += 8;
+ make2(q3, signs[3], bits.b2.val + 2);
}
SimpleBits bits;
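DequantizerIQ3XXS follows the same pattern with 32-bit grid entries (four quants each), so a 32-quant block consumes eight codebook bytes instead of four; only the sign handling is shared through apply_signs_2(). A scalar sketch of one such block, illustrative only, ignoring the d and per-block scales and again assuming keven_signs holds ±1 bytes:

inline void iq3xxs_block_scalar(const uint8_t * q3, uint32_t sidx, int8_t * out) {
    for (int k = 0; k < 8; ++k) {                        // 8 grid rows of 4 quants
        uint32_t g = iq3xxs_grid[q3[k]];
        const int8_t * s = (const int8_t *)(keven_signs + ((sidx >> 7*(k/2)) & 127));
        for (int l = 0; l < 4; ++l)
            out[4*k + l] = (int8_t)(g >> 8*l) * s[4*(k & 1) + l];   // magnitude * sign
    }
}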