summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Iwan Kawrakow <iwan.kawrakow@gmail.com>, 2024-05-28 12:10:52 +0200
committer: Iwan Kawrakow <iwan.kawrakow@gmail.com>, 2024-06-22 12:02:49 +0300
commit: 4b27ade2fb983da8210bde47e2fd913b7d92a30a (patch)
tree: 3c723c6c01853f0dfb10e17382fd661ac259298f
parent: 221a2c38070040c679c56a7d4c598508d485a759 (diff)
iqk_mul_mat: Arm implementation for iq3_s (llama.cpp version)
Here we get 3.65X (!) for PP-512 (53 t/s).
-rw-r--r--  iqk_mul_mat.cpp  103
1 file changed, 88 insertions, 15 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 08f2bd47..7c56f0ef 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -15,6 +15,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <cstring>
#include <type_traits>
#if defined __x86_64__ || defined __aarch64__
@@ -2217,6 +2218,14 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
};
+inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16, const uint8x16_t& smask, const uint8x16_t& step, // Multiplies the 16 int8 lanes of b[0] in place by per-lane factors derived from packed sign bits, then advances `shuffle` by `step`.
+ const uint8x16_t& m1, uint8x16_t& shuffle) {
+ auto aux = vqtbl1q_u8(signs16, shuffle); // lane k receives the byte of signs16 selected by shuffle[k]
+ auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); // 0xFF (-1) where every smask bit of the lane is set in aux; otherwise the m1 lane value (+1 when m1 is all ones)
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); // signed 8-bit multiply by the per-lane factor; only b[0] is touched
+ shuffle = vaddq_u8(shuffle, step); // shuffle is an in/out argument: a following call reads the sign bytes at the advanced indices
+}
+
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
@@ -2227,13 +2236,6 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
return prepare_4bit_scales16(x[i].scales);
-
- //auto aux1 = vld1_u8(x[i].scales);
- //auto aux2 = vshr_n_u8(aux1, 4);
- //auto scales8 = vqtbl1q_u8(vandq_u8(vcombine_u8(aux1, aux2), vdupq_n_u8(0xf)), vreinterpretq_u8_u64(scale_shuffle));
- //scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(scales8, 1), vdupq_n_u8(1)));
- //int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
- //return make_wider(scales16);
}
static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh,
@@ -2246,17 +2248,11 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
aux32[1] &= 0x03000300;
b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),
vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));
- auto aux1 = vqtbl1q_u8(signs16, shuffle);
- auto s1 = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux1, smask), smask), m1));
- b[2*k+0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2*k+0]), s1));
- shuffle = vaddq_u8(shuffle, step);
+ apply_signs_1(b+2*k+0, signs16, smask, step, m1, shuffle);
b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),
vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));
- auto aux2 = vqtbl1q_u8(signs16, shuffle);
- auto s2 = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux2, smask), smask), m1));
- b[2*k+1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2*k+1]), s2));
- shuffle = vaddq_u8(shuffle, step);
+ apply_signs_1(b+2*k+1, signs16, smask, step, m1, shuffle);
}
}
@@ -2315,6 +2311,80 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
};
+struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> { // Dequantizer for block_iq3_s data, written with Arm NEON intrinsics.
+ DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} // forwards all arguments to the base class
+ // Compile-time traits; their consumers are outside this view (presumably the templated mul_mat kernel -- confirm).
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+ // new_block: stores the converted x[i].d in `d` and returns the 8 packed 4-bit scales of block i, each mapped to 2*s+1 and widened to int32. The q8 and acc arguments are unused.
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = GGML_FP16_TO_FP32(x[i].d); // block scale converted via GGML_FP16_TO_FP32
+ uint32_t scales32[2];
+ std::memcpy(scales32, x[i].scales, 4); // 4 bytes, each holding two 4-bit values
+ scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; // high nibbles -> 2*s+1; must run before scales32[0] is overwritten on the next line
+ scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; // low nibbles -> 2*s+1
+ auto scales8 = vld1_u8((const uint8_t *)scales32); // lane order at this point: 0, 2, 4, 6, 1, 3, 5, 7
+ scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); // byte table {0,4,1,5,2,6,3,7} interleaves the two halves back into order 0..7
+ auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); // zero-extend to 16 bits (each value is at most 2*15+1 = 31)
+ int32x4x2_t scales;
+ scales.val[0] = vmovl_s16(vget_low_s16(scales16)); // scales 0..3
+ scales.val[1] = vmovl_s16(vget_high_s16(scales16)); // scales 4..7
+ return scales;
+ }
+ // make2: fills b[0] and b[1] (16 bytes each) from 8 grid lookups and applies signs to both; `shuffle` is advanced twice.
+ static inline void make2(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint16x8_t& idx_l, uint8_t qh,
+ const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, const int8x16_t& hshift, uint8x16_t * b) {
+ auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); // OR bit 8 into each 16-bit index: lane k gets (qh << hshift[k]) & 256, i.e. bit k of qh when hshift = {8,...,1}
+ const uint16_t * idx = (const uint16_t *)&vindex; // read the 8 lanes back as scalars through a pointer cast
+ b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); // each iq3s_grid entry fills one 32-bit lane
+ b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
+ apply_signs_1(b+0, signs16, smask, step, m1, shuffle); // also advances shuffle by step
+ apply_signs_1(b+1, signs16, smask, step, m1, shuffle);
+ }
+ static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh, // make4: fills b[0..3] from the 16 index bytes at qs and the high-bit bytes qh[0], qh[1].
+ const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, const int8x16_t& hshift, uint8x16_t * b) {
+ auto idx_l = vld1q_u8(qs); // 16 low index bytes
+ make2(signs16, shuffle, vmovl_u8(vget_low_u8 (idx_l)), qh[0], smask, step, m1, hshift, b+0); // indices 0..7 widened to 16 bits, high bits from qh[0]
+ make2(signs16, shuffle, vmovl_u8(vget_high_u8(idx_l)), qh[1], smask, step, m1, hshift, b+2); // indices 8..15, high bits from qh[1]; the commented-out lines below are these two calls written out inline
+ //auto vindex = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[0]), hshift), vdupq_n_u16(256)));
+ //const uint16_t * idx = (const uint16_t *)&vindex;
+ //b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
+ //b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
+ //apply_signs_1(b+0, signs16, smask, step, m1, shuffle);
+ //apply_signs_1(b+1, signs16, smask, step, m1, shuffle);
+ //vindex = vorrq_u16(vmovl_u8(vget_high_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[1]), hshift), vdupq_n_u16(256)));
+ //b[2] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
+ //b[3] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
+ //apply_signs_1(b+2, signs16, smask, step, m1, shuffle);
+ //apply_signs_1(b+3, signs16, smask, step, m1, shuffle);
+ }
+ // prepare: decodes chunk j of block i (32 qs bytes, 4 qh bytes, 16 sign bytes) into bits.b1.val and bits.b2.val with signs applied.
+ inline void prepare(int i, int j) {
+ // Per-lane left shifts that move bit k of a qh byte to bit position 8 in lane k.
+ static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
+ // Constants shared by every make4 / apply_signs_1 call below.
+ const auto smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); // bytes 0x01,0x02,...,0x80 in each 64-bit half: one distinct bit per lane
+ const auto m1 = vdupq_n_u8(1); // makes the sign factor +1 where the sign bit is clear
+ const auto step = vdupq_n_u8(2); // each apply_signs_1 call moves on by two sign bytes
+ const auto hshift = vld1q_s16(k_shift);
+ // Inputs for chunk j.
+ const auto * qs = x[i].qs + 32*j;
+ const auto * qh = x[i].qh + 4*j;
+ const auto signs16 = vld1q_u8(x[i].signs + 16*j);
+ // Lanes 0-7 start at sign byte 0 and lanes 8-15 at sign byte 1; the eight apply_signs_1 calls below walk through all 16 sign bytes.
+ auto shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));
+ make4(signs16, shuffle, qs+ 0, qh+0, smask, step, m1, hshift, bits.b1.val); // first 16 indices, qh[0..1]
+ make4(signs16, shuffle, qs+16, qh+2, smask, step, m1, hshift, bits.b2.val); // next 16 indices, qh[2..3]
+ }
+ // Output of prepare(); SimpleBits is declared outside this view.
+ SimpleBits bits;
+ uint32x4x2_t gas; // NOTE(review): not referenced by any method of this struct -- confirm whether it is needed
+
+ float d; // block scale, set by new_block()
+
+};
+
template <int nrc_y, typename Dequantizer>
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -2872,6 +2942,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
case GGML_TYPE_IQ3_XXS:
MulMat::set_functions<DequantizerIQ3XXS>(m);
break;
+ case GGML_TYPE_IQ3_S:
+ MulMat::set_functions<DequantizerIQ3S>(m);
+ break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);