-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp | 446
-rw-r--r--   src/llama.cpp                |   4
2 files changed, 189 insertions, 261 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index b6ff7ab7..ba6ad15d 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6292,7 +6292,7 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
};
-template <int nrc_y, typename Dequantizer>
+template <typename Dequantizer, int nrc_y>
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
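
The only change in this hunk is the template parameter order of mul_mat_qX_K_q8_K_T: the dequantizer type now comes first, matching mul_mat_qX_0_q8_0 and mul_mat_qX_1_q8_1, so that all three families can be instantiated by the SET_MUL_MAT_FUNCTIONS_T macro added further down. A toy illustration of why a single instantiation macro forces a common parameter order (all names below are made up):

    #include <cstdio>

    struct DeqA { static constexpr const char * name = "A"; };

    // The macro always passes the dequantizer first, so every kernel it
    // instantiates has to declare <typename Dequantizer, int nrc_y>.
    template <typename Dequantizer, int nrc_y>
    void toy_mul_mat(int n) { std::printf("%s ny=%d n=%d\n", Dequantizer::name, nrc_y, n); }

    using fn_t = void (*)(int);

    #define TOY_SET_FUNCS(table, func, Dequantizer) \
        table[0] = func<Dequantizer, 1>;            \
        table[1] = func<Dequantizer, 2>;

    int main() {
        fn_t table[2];
        TOY_SET_FUNCS(table, toy_mul_mat, DeqA);
        table[1](32);    // calls toy_mul_mat<DeqA, 2>
    }
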
@@ -7680,66 +7680,50 @@ IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint
}
template <int nrc_y>
-void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_0_x4> q8(info);
- auto m4 = vdupq_n_u8(0xf);
- auto values = vld1q_s8(iq4k_values);
- int nb = n / QK4_NL;
- GGML_ASSERT(nb%4 == 0);
- int8x16_t qx[8];
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int k = 0; k < 4; ++k) {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
- auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
- prepare_iq4_nl_quants(values, m4, bits, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi = interleaved_dotq(qx, y);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
- acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
- }
-}
-
-template <int nrc_y>
void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = vdupq_n_u8(0xf);
+ auto m3 = vdupq_n_u8(0x30);
+ auto m32 = vdupq_n_s8(-32);
auto values = vld1q_s8(iq4k_values);
int nbl = n / QK_K;
int8x16_t qx[8];
+ int8x16x2_t iscales;
+ float32x4x4_t scales;
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
- const uint32_t * scales_l = (const uint32_t *)iq4[ibl].scales_l;
- const uint32_t * scales_h = (const uint32_t *)iq4[ibl].scales_h;
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
- for (int ib = 0; ib < QK_K/32; ++ib) {
- auto ul = (scales_l[ib%4] >> 4*(ib/4)) & 0x0f0f0f0f;
- auto uh = (scales_h[ib%2] >> 2*(ib/2)) & 0x03030303;
- auto sl8 = vsub_s8(vreinterpret_s8_s32(vdup_n_s32(ul | (uh << 4))), vdup_n_s8(32));
- auto sl16 = vmovl_s8(sl8);
- auto sl32 = vmovl_s16(vget_low_s16(sl16));
- auto scales = vmulq_f32(d4, vcvtq_f32_s32(sl32));
- auto bits = vld1q_u8_x4(iq4[ibl].qs + 64*ib);
- prepare_iq4_nl_quants(values, m4, bits, qx);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+32*ib);
- auto sumi = interleaved_dotq(qx, y);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(q8.scale(iy, ibl)));
- acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ if constexpr (nrc_y == 1) {
+ d4 = vmulq_f32(d4, vdupq_n_f32(q8.scale(0, ibl)));
+ }
+ auto sl = vld1q_u8(iq4[ibl].scales_l);
+ auto sh8 = vld1_u8(iq4[ibl].scales_h);
+ auto sh = vcombine_u8(sh8, vshr_n_u8(sh8, 2));
+ iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl, m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
+ iscales.val[1] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl, 4), vandq_u8(sh, m3)), m32);
+ for (int is = 0; is < 2; ++is) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
+ scales.val[0] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
+ scales.val[1] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
+ scales.val[2] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
+ scales.val[3] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
+ prepare_iq4_nl_quants(values, m4, bits, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ if constexpr (nrc_y == 1) {
+ acc[iy] = vfmaq_f32(acc[iy], scales.val[ib], vcvtq_f32_s32(sumi));
+ } else {
+ auto d4d8 = vmulq_f32(scales.val[ib], vdupq_n_f32(q8.scale(iy, ibl)));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
}
}
}
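
The rewritten iq4_xs_r4 path above unpacks all eight 6-bit block scales of a super-block into two 16-byte vectors up front (and, for nrc_y == 1, folds the q8 super-block scale into d4), instead of decoding one 4-scale group per 32-value block. A scalar model of the scale layout both the old and new code decode — low 4 bits from scales_l, high 2 bits from scales_h, bias of -32 — with made-up input data; the NEON code computes the same values 16 lanes at a time:

    // Scalar model of the 6-bit scale reconstruction. For interleaved row r
    // (0..3) and block ib (0..7) of a super-block:
    //   low 4 bits  : nibble ib/4 of byte scales_l[4*(ib%4) + r]
    //   high 2 bits : bit-pair ib/2 of byte scales_h[4*(ib%2) + r]
    //   scale       = (low | high << 4) - 32
    // The field names mirror block_iq4_xs_r4; the test data is arbitrary.
    #include <cstdint>
    #include <cstdio>

    static int unpack_scale(const uint8_t * scales_l, const uint8_t * scales_h, int ib, int r) {
        int lo = (scales_l[4*(ib%4) + r] >> 4*(ib/4)) & 0x0f;
        int hi = (scales_h[4*(ib%2) + r] >> 2*(ib/2)) & 0x03;
        return (lo | (hi << 4)) - 32;
    }

    int main() {
        uint8_t scales_l[16], scales_h[8];
        for (int i = 0; i < 16; ++i) scales_l[i] = uint8_t(37*i + 11);   // arbitrary test bytes
        for (int i = 0; i <  8; ++i) scales_h[i] = uint8_t(53*i + 29);
        for (int ib = 0; ib < 8; ++ib)
            for (int r = 0; r < 4; ++r)
                printf("block %d row %d: scale %d\n", ib, r, unpack_scale(scales_l, scales_h, ib, r));
    }
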
@@ -7793,152 +7777,145 @@ void mul_mat_iq4_nl_x4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo&
}
}
-template <int nrc_y>
-void mul_mat_q4_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+template <typename Dequantizer, int nrc_y>
+void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
- auto m4 = vdupq_n_u8(0xf0);
- auto m88 = vdupq_n_u8(0x88);
- auto norm = vdupq_n_f32(1.f/16);
+ Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
+ float d8[4*nrc_y];
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
+ deq.new_row(ix);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+ }
for (int k = 0; k < 4; ++k) {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
- auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
- for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
- qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
- qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
- qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
- qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
- qx[4] = vandq_u8(bits.val[0], m4); // 8..11
- qx[5] = vandq_u8(bits.val[1], m4); // 24..27
- qx[6] = vandq_u8(bits.val[2], m4); // 12..15
- qx[7] = vandq_u8(bits.val[3], m4); // 28..31
+ auto scales = deq.prepare(ib4, k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
+ auto sumi = interleaved_dotq(qx, y);
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vmulq_f32(norm, acc[iy]));
+ info.store(ix, iy, deq.result(acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
-template <int nrc_y>
-void mul_mat_q5_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_0_x4> q8(info);
- auto m4 = vdupq_n_u8(0x0f);
- auto m5 = vdupq_n_u8(0x10);
- auto m16 = vdupq_n_s8(-16);
- int nb = n / QK5_0;
- GGML_ASSERT(nb%4 == 0);
- int8x16_t qx[8];
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int k = 0; k < 4; ++k) {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
- auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
- auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
- qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
- qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
- qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
- qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
- qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
- qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
- qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
- qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
- acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
+struct IQ4_NL_R4_Dequantizer {
+ IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
+ inline void new_row(int ix) { iq4 = (const block_iq4_nl_x4 *)(cx + ix*bx); }
+ inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
+ auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
+ prepare_iq4_nl_quants(values, m4, bits, qx);
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return acc;
}
-}
-template <int nrc_y>
-void mul_mat_q6_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
- Q8<nrc_y, block_q8_0_x4> q8(info);
- auto m4 = vdupq_n_u8(0x0f);
- auto m6 = vdupq_n_u8(0x30);
- auto m32 = vdupq_n_s8(-32);
- int nb = n / QK6_0;
- GGML_ASSERT(nb%4 == 0);
- int8x16_t qx[8];
- float32x4_t acc[nrc_y] = {};
- for (int ix = 0; ix < nrc_x; ix += 4) {
- const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
- for (int ib4 = 0; ib4 < nb/4; ++ib4) {
- for (int k = 0; k < 4; ++k) {
- auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d));
- auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs);
- auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh);
- qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
- qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
- qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
- qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23
- qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11
- qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27
- qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15
- qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
- acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
- }
- }
- }
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, acc[iy]);
- acc[iy] = vdupq_n_f32(0.f);
- }
+ const char * cx;
+ const size_t bx;
+ const block_iq4_nl_x4 * iq4;
+ const uint8x16_t m4 = vdupq_n_u8(0x0f);
+ const int8x16_t values;
+};
+
+struct Q4_0_R4_Dequantizer {
+ Q4_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+ inline void new_row(int ix) { iq4 = (const block_iq4_nl_x4 *)(cx + ix*bx); }
+ inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
+ auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
+ for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
+ qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
+ qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
+ qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
+ qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
+ qx[4] = vandq_u8(bits.val[0], m4); // 8..11
+ qx[5] = vandq_u8(bits.val[1], m4); // 24..27
+ qx[6] = vandq_u8(bits.val[2], m4); // 12..15
+ qx[7] = vandq_u8(bits.val[3], m4); // 28..31
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return vmulq_f32(norm, acc);
}
-}
+
+ const char * cx;
+ const size_t bx;
+ const block_iq4_nl_x4 * iq4;
+ const uint8x16_t m4 = vdupq_n_u8(0xf0);
+ const uint8x16_t m88 = vdupq_n_u8(0x88);
+ const float32x4_t norm = vdupq_n_f32(1.f/16);
+};
+
+struct Q5_0_R4_Dequantizer {
+ Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+ inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
+ inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
+ auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
+ auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
+ qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
+ qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
+ qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
+ qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
+ qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
+ qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
+ qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
+ qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return acc;
+ }
+
+ const char * cx;
+ const size_t bx;
+ const block_q5_0_r4 * iq5;
+ const uint8x16_t m4 = vdupq_n_u8(0x0f);
+ const uint8x16_t m5 = vdupq_n_u8(0x10);
+ const int8x16_t m16 = vdupq_n_s8(-16);
+};
+
+struct Q6_0_R4_Dequantizer {
+ Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
+ inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
+ inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d));
+ auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs);
+ auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh);
+ qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
+ qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
+ qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
+ qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23
+ qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11
+ qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27
+ qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15
+ qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31
+ return scales;
+ }
+ inline float32x4_t result(float32x4_t acc) const {
+ return acc;
+ }
+
+ const char * cx;
+ const size_t bx;
+ const block_q6_0_r4 * iq6;
+ const uint8x16_t m4 = vdupq_n_u8(0x0f);
+ const uint8x16_t m6 = vdupq_n_u8(0x30);
+ const int8x16_t m32 = vdupq_n_s8(-32);
+};
template <int nrc_y>
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
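
The hunk above folds mul_mat_q4_0_r4_q8_0, mul_mat_q5_0_r4_q8_0 and mul_mat_q6_0_r4_q8_0 (plus the mul_mat_iq4_nl_x4_q8_0 removed earlier) into one template, mul_mat_qx_r4_q8_0, parameterized by a small dequantizer struct with three hooks: new_row(ix) sets the row-group pointer, prepare(ib4, k, qx) unpacks one block into qx and returns its scales, and result(acc) applies any final normalization (only Q4_0_R4 needs it, the 1/16 factor). A minimal scalar sketch of that dispatch pattern, with toy types and toy math rather than the real block layouts or NEON intrinsics:

    #include <cstdio>

    // Toy dequantizer: the kernel owns the loop structure and the q8 side,
    // the dequantizer owns row pointers, bit unpacking and normalization.
    struct ToyNibbleDequantizer {
        ToyNibbleDequantizer(const unsigned char * vx, int bx) : cx(vx), bx(bx) {}
        void new_row(int ix) { row = cx + ix*bx; }
        float prepare(int ib, signed char * q) const {        // unpack 16 4-bit values of block ib
            for (int j = 0; j < 16; ++j)
                q[j] = (signed char)((row[16*ib + j] & 0x0f) - 8);
            return 0.25f;                                      // stand-in for the fp16 block scale
        }
        float result(float acc) const { return acc; }          // Q4_0_R4 returns acc/16 here instead
        const unsigned char * cx;
        int bx;
        const unsigned char * row = nullptr;
    };

    template <typename Dequantizer>
    float toy_row_dot(const void * vx, int bx, const signed char * y, int nblocks) {
        Dequantizer deq((const unsigned char *)vx, bx);
        deq.new_row(0);
        signed char q[16];
        float acc = 0.0f;
        for (int ib = 0; ib < nblocks; ++ib) {
            float scale = deq.prepare(ib, q);
            int sumi = 0;
            for (int j = 0; j < 16; ++j) sumi += q[j] * y[16*ib + j];
            acc += scale * (float)sumi;                        // vfmaq_f32 in the real kernel
        }
        return deq.result(acc);
    }

    int main() {
        unsigned char x[32]; signed char y[32];
        for (int i = 0; i < 32; ++i) { x[i] = (unsigned char)(i*7); y[i] = (signed char)(i - 16); }
        printf("%g\n", toy_row_dot<ToyNibbleDequantizer>(x, 32, y, 2));
    }
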
@@ -7947,9 +7924,13 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
float32x4_t acc[nrc_y] = {};
+ float d8[4*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
+ }
for (int k = 0; k < 4; ++k) {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[4*ib4+k].d));
auto qx1 = vld1q_s8_x4(iq8[4*ib4+k].qs);
@@ -7965,7 +7946,7 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
sumi = vdotq_laneq_s32(sumi, qx2.val[1], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx2.val[2], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx2.val[3], y.val[1], 3);
- auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
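
This hunk (like the matching d8[] buffer in mul_mat_qx_r4_q8_0 above) hoists the q8 scale conversion out of the inner loops: the four fp16 block scales of a 4-block group are converted to fp32 once per row (a single vcvt_f32_f16) instead of calling GGML_FP16_TO_FP32 for every (k, iy) pair. A scalar sketch of the same hoisting; fp16_to_fp32() below is a reference stand-in, not the ggml helper, and the scale values are made up:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    static float fp16_to_fp32(uint16_t h) {
        int   sign = (h & 0x8000) ? -1 : 1;
        int   exp  = (h >> 10) & 0x1f;
        float man  = (float)(h & 0x3ff);
        if (exp == 0)    return sign * std::ldexp(man, -24);              // zero / subnormal
        if (exp == 0x1f) return (h & 0x3ff) ? NAN : sign * INFINITY;      // NaN / inf
        return sign * std::ldexp(1.0f + man / 1024.0f, exp - 15);
    }

    int main() {
        const int nrc_y = 4;
        uint16_t d_fp16[nrc_y][4] = {{0x3c00,0x4000,0x3800,0x4200},       // made-up block scales
                                     {0x3c00,0x3c00,0x3c00,0x3c00},
                                     {0x4400,0x3000,0x3e00,0x4000},
                                     {0x3555,0x3a66,0x4100,0x2c00}};
        float d8[4*nrc_y];
        for (int iy = 0; iy < nrc_y; ++iy)                                // once per 4-block group
            for (int k = 0; k < 4; ++k) d8[4*iy + k] = fp16_to_fp32(d_fp16[iy][k]);
        for (int k = 0; k < 4; ++k)                                       // inner loops only index d8
            for (int iy = 0; iy < nrc_y; ++iy)
                printf("k=%d iy=%d d8=%g\n", k, iy, d8[4*iy + k]);
    }
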
@@ -7977,38 +7958,37 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
+ m.funcs[0] = func<Dequantizer, 1>;\
+ m.funcs[1] = func<Dequantizer, 2>;\
+ m.funcs[2] = func<Dequantizer, 3>;\
+ m.funcs[3] = func<Dequantizer, 4>;\
+ m.funcs[4] = func<Dequantizer, 5>;\
+ m.funcs[5] = func<Dequantizer, 6>;\
+ m.funcs[6] = func<Dequantizer, 7>;\
+ m.funcs[7] = func<Dequantizer, 8>;\
+
+#define SET_MUL_MAT_FUNCTIONS(m, func) \
+ m.funcs[0] = func<1>;\
+ m.funcs[1] = func<2>;\
+ m.funcs[2] = func<3>;\
+ m.funcs[3] = func<4>;\
+ m.funcs[4] = func<5>;\
+ m.funcs[5] = func<6>;\
+ m.funcs[6] = func<7>;\
+ m.funcs[7] = func<8>;\
+
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> ||
std::is_same_v<Dequantizer, DequantizerQ60>) {
- m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_0_q8_0<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_0_q8_0<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_0_q8_0<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_0_q8_0<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_0_q8_0<Dequantizer, 8>;
+ SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_0_q8_0, Dequantizer);
}
else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {
- m.funcs[0] = mul_mat_qX_1_q8_1<Dequantizer, 1>;
- m.funcs[1] = mul_mat_qX_1_q8_1<Dequantizer, 2>;
- m.funcs[2] = mul_mat_qX_1_q8_1<Dequantizer, 3>;
- m.funcs[3] = mul_mat_qX_1_q8_1<Dequantizer, 4>;
- m.funcs[4] = mul_mat_qX_1_q8_1<Dequantizer, 5>;
- m.funcs[5] = mul_mat_qX_1_q8_1<Dequantizer, 6>;
- m.funcs[6] = mul_mat_qX_1_q8_1<Dequantizer, 7>;
- m.funcs[7] = mul_mat_qX_1_q8_1<Dequantizer, 8>;
+ SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_1_q8_1, Dequantizer);
}
else {
- m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>;
- m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>;
- m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>;
- m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>;
- m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>;
- m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>;
- m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>;
- m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>;
+ SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_K_q8_K_T, Dequantizer);
}
}
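
SET_MUL_MAT_FUNCTIONS and SET_MUL_MAT_FUNCTIONS_T expand to exactly the eight funcs[i] assignments they replace, with nrc_y running from 1 to 8. For comparison only — this is not what the patch does — the same table could be filled without a macro using an index sequence; the kernel and table below are toy stand-ins:

    #include <cstddef>
    #include <cstdio>
    #include <utility>

    using mul_mat_t = void (*)(int);
    struct ToyMulMat { mul_mat_t funcs[8]; };

    // Toy "kernel": in the real code this would be mul_mat_qX_K_q8_K_T<Dequantizer, nrc_y>.
    template <int nrc_y> void toy_kernel(int n) { std::printf("nrc_y=%d, n=%d\n", nrc_y, n); }

    // Wrap the function template in a struct so it can be passed as a type,
    // then expand nrc_y = 1..8 from an index sequence.
    struct ToyKernel {
        template <int nrc_y> static constexpr mul_mat_t get() { return toy_kernel<nrc_y>; }
    };

    template <typename K, std::size_t... I>
    void toy_set_functions(ToyMulMat & m, std::index_sequence<I...>) {
        ((m.funcs[I] = K::template get<static_cast<int>(I) + 1>()), ...);
    }

    int main() {
        ToyMulMat m;
        toy_set_functions<ToyKernel>(m, std::make_index_sequence<8>{});
        m.funcs[2](64);    // calls toy_kernel<3>
    }
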
@@ -8097,25 +8077,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerIQ3S>(m);
break;
case GGML_TYPE_IQ1_BN:
- m.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
- m.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
- m.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
- m.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
- m.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
- m.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
- m.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
- m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1bn_q8_K64);
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
- m.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
- m.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
- m.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
- m.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
- m.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
- m.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
- m.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
- m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2bn_q8_K64);
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN_R4:
@@ -8158,69 +8124,27 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_IQ4_NL_X4:
- m.funcs[0] = mul_mat_iq4_nl_x4_q8_0_1;
- m.funcs[1] = mul_mat_iq4_nl_x4_q8_0<2>;
- m.funcs[2] = mul_mat_iq4_nl_x4_q8_0<3>;
- m.funcs[3] = mul_mat_iq4_nl_x4_q8_0<4>;
- m.funcs[4] = mul_mat_iq4_nl_x4_q8_0<5>;
- m.funcs[5] = mul_mat_iq4_nl_x4_q8_0<6>;
- m.funcs[6] = mul_mat_iq4_nl_x4_q8_0<7>;
- m.funcs[7] = mul_mat_iq4_nl_x4_q8_0<8>;
+ SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_IQ4_XS_R4:
- m.funcs[0] = mul_mat_iq4_nl_x4_q8_0_1;
- m.funcs[1] = mul_mat_iq4_xs_r4_q8_k<2>;
- m.funcs[2] = mul_mat_iq4_xs_r4_q8_k<3>;
- m.funcs[3] = mul_mat_iq4_xs_r4_q8_k<4>;
- m.funcs[4] = mul_mat_iq4_xs_r4_q8_k<5>;
- m.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>;
- m.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>;
- m.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q4_0_R4:
- m.funcs[0] = mul_mat_q4_0_r4_q8_0<1>;
- m.funcs[1] = mul_mat_q4_0_r4_q8_0<2>;
- m.funcs[2] = mul_mat_q4_0_r4_q8_0<3>;
- m.funcs[3] = mul_mat_q4_0_r4_q8_0<4>;
- m.funcs[4] = mul_mat_q4_0_r4_q8_0<5>;
- m.funcs[5] = mul_mat_q4_0_r4_q8_0<6>;
- m.funcs[6] = mul_mat_q4_0_r4_q8_0<7>;
- m.funcs[7] = mul_mat_q4_0_r4_q8_0<8>;
+ SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q5_0_R4:
- m.funcs[0] = mul_mat_q5_0_r4_q8_0<1>;
- m.funcs[1] = mul_mat_q5_0_r4_q8_0<2>;
- m.funcs[2] = mul_mat_q5_0_r4_q8_0<3>;
- m.funcs[3] = mul_mat_q5_0_r4_q8_0<4>;
- m.funcs[4] = mul_mat_q5_0_r4_q8_0<5>;
- m.funcs[5] = mul_mat_q5_0_r4_q8_0<6>;
- m.funcs[6] = mul_mat_q5_0_r4_q8_0<7>;
- m.funcs[7] = mul_mat_q5_0_r4_q8_0<8>;
+ SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q6_0_R4:
- m.funcs[0] = mul_mat_q6_0_r4_q8_0<1>;
- m.funcs[1] = mul_mat_q6_0_r4_q8_0<2>;
- m.funcs[2] = mul_mat_q6_0_r4_q8_0<3>;
- m.funcs[3] = mul_mat_q6_0_r4_q8_0<4>;
- m.funcs[4] = mul_mat_q6_0_r4_q8_0<5>;
- m.funcs[5] = mul_mat_q6_0_r4_q8_0<6>;
- m.funcs[6] = mul_mat_q6_0_r4_q8_0<7>;
- m.funcs[7] = mul_mat_q6_0_r4_q8_0<8>;
+ SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q8_0_R4:
- m.funcs[0] = mul_mat_q8_0_r4_q8_0<1>;
- m.funcs[1] = mul_mat_q8_0_r4_q8_0<2>;
- m.funcs[2] = mul_mat_q8_0_r4_q8_0<3>;
- m.funcs[3] = mul_mat_q8_0_r4_q8_0<4>;
- m.funcs[4] = mul_mat_q8_0_r4_q8_0<5>;
- m.funcs[5] = mul_mat_q8_0_r4_q8_0<6>;
- m.funcs[6] = mul_mat_q8_0_r4_q8_0<7>;
- m.funcs[7] = mul_mat_q8_0_r4_q8_0<8>;
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0);
expected_Btype = GGML_TYPE_Q8_0;
break;
default:
@@ -8437,7 +8361,7 @@ struct F16 {
#else
using Data = float16x8_t;
constexpr static int block_size = 8;
- constexpr static int num_registers = 32;
+ //constexpr static int num_registers = 32;
constexpr static int q_step = 8;
static inline Data zero() { return vdupq_n_f16(0); }
static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }
diff --git a/src/llama.cpp b/src/llama.cpp
index ad76a7b8..0e1aadbd 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16569,6 +16569,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_Q5_0_R4) {
+ if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0;
+ else chunk_size_multiplier = 4;
+ }
else if (new_type == GGML_TYPE_Q6_0_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_0;
else chunk_size_multiplier = 4;
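
The llama.cpp change adds the same guard for Q5_0_R4 that already exists for the other row-interleaved types: the R4 formats pack four rows per block, so a tensor whose second dimension is not a multiple of 4 falls back to the plain quant (here Q5_0), and otherwise the quantization chunk size is multiplied by 4. A condensed model of that selection; the enum and table below are simplified stand-ins, not the llama.cpp types:

    #include <cstdint>
    #include <cstdio>

    enum ToyType { Q4_0, Q4_0_R4, Q5_0, Q5_0_R4, Q6_0, Q6_0_R4 };

    struct R4Fallback { ToyType r4, plain; };
    static const R4Fallback k_fallbacks[] = {
        { Q4_0_R4, Q4_0 }, { Q5_0_R4, Q5_0 }, { Q6_0_R4, Q6_0 },
    };

    static ToyType pick_type(ToyType wanted, int64_t ne1, int & chunk_size_multiplier) {
        chunk_size_multiplier = 1;
        for (const auto & f : k_fallbacks) {
            if (wanted != f.r4) continue;
            if (ne1 % 4 != 0) return f.plain;   // cannot interleave 4 rows
            chunk_size_multiplier = 4;          // quantize 4 rows per chunk
            return wanted;
        }
        return wanted;
    }

    int main() {
        int mult;
        ToyType t = pick_type(Q5_0_R4, /*ne1=*/4096, mult);
        std::printf("type=%d mult=%d\n", (int)t, mult);
        t = pick_type(Q5_0_R4, /*ne1=*/4095, mult);
        std::printf("type=%d mult=%d\n", (int)t, mult);
    }
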