From b51922530f0b80602c007d14a00d1ffccda04d28 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 27 May 2024 11:05:44 +0200 Subject: iqk_mul_mat: faster q3_K TG We get 31 t/s up from 26 t/s, but we need to treat PP differently from TG, else we get a ~10% drop in PP performance. --- iqk_mul_mat.cpp | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) (limited to 'iqk_mul_mat.cpp') diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 7c1afa39..aa364900 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1763,6 +1763,7 @@ struct DequantizerQ3K final : public BaseDequantizer { inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); h.bits = vld1q_u8_x2(x[i].hmask); + mask = vdupq_n_u8(0x01); const uint16_t * sc16 = (const uint16_t *)x[i].scales; uint32_t aux0 = sc16[0] | (sc16[1] << 16); uint32_t aux1 = sc16[2] | (sc16[3] << 16); @@ -1771,19 +1772,43 @@ struct DequantizerQ3K final : public BaseDequantizer { aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); - return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d); + auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)); + if (nrc > 1) { + return process_scales_mins_16(scales8, q8, acc, i, -4.f*d); + } + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); + scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); + return make_wider(scales16); } inline void prepare(int i, int j) { bits.prepare(x[i].qs+32*j); - h.apply(bits.b1, bits.b2, j == 0); + if (nrc > 1) { + h.apply(bits.b1, bits.b2, j == 0); + } else { + auto minus4 = vdupq_n_u8(0xfc); + auto zero = vdupq_n_u8(0); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + } } uint32_t aux32[4]; Q2bits bits; - const uint8x16_t mhb = vdupq_n_u8(0x04); + uint8x16_t mask; HighBit3 h; float d; -- cgit v1.2.3