summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp31
1 files changed, 28 insertions, 3 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 7c1afa39..aa364900 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1763,6 +1763,7 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
h.bits = vld1q_u8_x2(x[i].hmask);
+ mask = vdupq_n_u8(0x01);
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
@@ -1771,19 +1772,43 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
- return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d);
+ auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
+ if (nrc > 1) {
+ return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
+ }
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
+ scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
+ return make_wider(scales16);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
- h.apply(bits.b1, bits.b2, j == 0);
+ if (nrc > 1) {
+ h.apply(bits.b1, bits.b2, j == 0);
+ } else {
+ auto minus4 = vdupq_n_u8(0xfc);
+ auto zero = vdupq_n_u8(0);
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ }
}
uint32_t aux32[4];
Q2bits bits;
- const uint8x16_t mhb = vdupq_n_u8(0x04);
+ uint8x16_t mask;
HighBit3 h;
float d;