diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-05-28 12:10:52 +0200 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:49 +0300 |
commit | 4b27ade2fb983da8210bde47e2fd913b7d92a30a (patch) | |
tree | 3c723c6c01853f0dfb10e17382fd661ac259298f | |
parent | 221a2c38070040c679c56a7d4c598508d485a759 (diff) |
iqk_mul_mat: Arm implementation for iq3_s (llama.cpp version)
Here we get 3.65X (!) for PP-512 (53 t/s).
-rw-r--r-- | iqk_mul_mat.cpp | 103 |
1 files changed, 88 insertions, 15 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 08f2bd47..7c56f0ef 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -15,6 +15,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <cstring> #include <type_traits> #if defined __x86_64__ || defined __aarch64__ @@ -2217,6 +2218,14 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> { }; +inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16, const uint8x16_t& smask, const uint8x16_t& step, + const uint8x16_t& m1, uint8x16_t& shuffle) { + auto aux = vqtbl1q_u8(signs16, shuffle); + auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); + b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); + shuffle = vaddq_u8(shuffle, step); +} + struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} @@ -2227,13 +2236,6 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); return prepare_4bit_scales16(x[i].scales); - - //auto aux1 = vld1_u8(x[i].scales); - //auto aux2 = vshr_n_u8(aux1, 4); - //auto scales8 = vqtbl1q_u8(vandq_u8(vcombine_u8(aux1, aux2), vdupq_n_u8(0xf)), vreinterpretq_u8_u64(scale_shuffle)); - //scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(scales8, 1), vdupq_n_u8(1))); - //int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) }; - //return make_wider(scales16); } static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh, @@ -2246,17 +2248,11 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> { aux32[1] &= 0x03000300; b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); - auto aux1 = vqtbl1q_u8(signs16, shuffle); - auto s1 = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux1, smask), smask), m1)); - b[2*k+0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2*k+0]), s1)); - shuffle = vaddq_u8(shuffle, step); + apply_signs_1(b+2*k+0, signs16, smask, step, m1, shuffle); b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); - auto aux2 = vqtbl1q_u8(signs16, shuffle); - auto s2 = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux2, smask), smask), m1)); - b[2*k+1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2*k+1]), s2)); - shuffle = vaddq_u8(shuffle, step); + apply_signs_1(b+2*k+1, signs16, smask, step, m1, shuffle); } } @@ -2315,6 +2311,80 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> { }; +struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { + DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template <typename Q8> + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = GGML_FP16_TO_FP32(x[i].d); + uint32_t scales32[2]; + std::memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; + auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7 + scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); + auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); + int32x4x2_t scales; + scales.val[0] = vmovl_s16(vget_low_s16(scales16)); + scales.val[1] = vmovl_s16(vget_high_s16(scales16)); + return scales; + } + + static inline void make2(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint16x8_t& idx_l, uint8_t qh, + const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, const int8x16_t& hshift, uint8x16_t * b) { + auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); + const uint16_t * idx = (const uint16_t *)&vindex; + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); + apply_signs_1(b+0, signs16, smask, step, m1, shuffle); + apply_signs_1(b+1, signs16, smask, step, m1, shuffle); + } + static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh, + const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, const int8x16_t& hshift, uint8x16_t * b) { + auto idx_l = vld1q_u8(qs); + make2(signs16, shuffle, vmovl_u8(vget_low_u8 (idx_l)), qh[0], smask, step, m1, hshift, b+0); + make2(signs16, shuffle, vmovl_u8(vget_high_u8(idx_l)), qh[1], smask, step, m1, hshift, b+2); + //auto vindex = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[0]), hshift), vdupq_n_u16(256))); + //const uint16_t * idx = (const uint16_t *)&vindex; + //b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); + //b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); + //apply_signs_1(b+0, signs16, smask, step, m1, shuffle); + //apply_signs_1(b+1, signs16, smask, step, m1, shuffle); + //vindex = vorrq_u16(vmovl_u8(vget_high_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[1]), hshift), vdupq_n_u16(256))); + //b[2] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); + //b[3] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); + //apply_signs_1(b+2, signs16, smask, step, m1, shuffle); + //apply_signs_1(b+3, signs16, smask, step, m1, shuffle); + } + + inline void prepare(int i, int j) { + + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + const auto smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); + const auto m1 = vdupq_n_u8(1); + const auto step = vdupq_n_u8(2); + const auto hshift = vld1q_s16(k_shift); + + const auto * qs = x[i].qs + 32*j; + const auto * qh = x[i].qh + 4*j; + const auto signs16 = vld1q_u8(x[i].signs + 16*j); + + auto shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); + make4(signs16, shuffle, qs+ 0, qh+0, smask, step, m1, hshift, bits.b1.val); + make4(signs16, shuffle, qs+16, qh+2, smask, step, m1, hshift, bits.b2.val); + } + + SimpleBits bits; + uint32x4x2_t gas; + + float d; + +}; + template <int nrc_y, typename Dequantizer> void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -2872,6 +2942,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int / case GGML_TYPE_IQ3_XXS: MulMat::set_functions<DequantizerIQ3XXS>(m); break; + case GGML_TYPE_IQ3_S: + MulMat::set_functions<DequantizerIQ3S>(m); + break; case GGML_TYPE_Q4_0: MulMat::set_functions<DequantizerQ40>(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); |