summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp108
1 files changed, 104 insertions, 4 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 13e6420b..511eea01 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -4941,6 +4941,26 @@ inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
}
+inline int32x4x2_t sum_4_blocks(const int8x16_t * b1, const int8x16_t * b2, const int8_t * qs) {
+ auto q8b = vld1q_s8_x2(qs + 0);
+ auto p12_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q8b.val[0]), b1[1], q8b.val[1]);
+ auto p12_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q8b.val[0]), b2[1], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 32);
+ auto p34_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q8b.val[0]), b1[3], q8b.val[1]);
+ auto p34_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q8b.val[0]), b2[3], q8b.val[1]);
+ auto p1234_1 = vpaddq_s32(p12_1, p34_1);
+ auto p1234_2 = vpaddq_s32(p12_2, p34_2);
+ q8b = vld1q_s8_x2(qs + 64);
+ auto p56_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[4], q8b.val[0]), b1[5], q8b.val[1]);
+ auto p56_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[4], q8b.val[0]), b2[5], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 96);
+ auto p78_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[6], q8b.val[0]), b1[7], q8b.val[1]);
+ auto p78_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[6], q8b.val[0]), b2[7], q8b.val[1]);
+ auto p5678_1 = vpaddq_s32(p56_1, p78_1);
+ auto p5678_2 = vpaddq_s32(p56_2, p78_2);
+ return { vpaddq_s32(p1234_1, p5678_1), vpaddq_s32(p1234_2, p5678_2)};
+}
+
template <int nrc> struct Q80 {
constexpr static int nrc_y = nrc;
@@ -4969,6 +4989,17 @@ template <int nrc> struct Q80 {
}
template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * /*acc*/) const {
+ auto qx_scales_1 = deq1.new_block(i);
+ auto qx_scales_2 = deq2.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ sc16[iy ] = vmul_f16(qx_scales_1, q8_scales);
+ sc16[iy+nrc_y] = vmul_f16(qx_scales_2, q8_scales);
+ }
+ }
+
+ template <typename Dequantizer>
inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
deq.prepare1(i);
float d = GGML_FP16_TO_FP32(deq.x[i].d);
@@ -5012,6 +5043,23 @@ template <int nrc> struct Q81 {
}
template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * acc) const {
+ auto qx_scales_1 = deq1.new_block(i);
+ auto qx_scales_2 = deq2.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ auto q8_scales_l = vget_low_f16(q8_scales);
+ auto q8_scales_h = vget_high_f16(q8_scales);
+ auto m1 = vmul_f16(vget_high_f16(qx_scales_1), q8_scales_h);
+ auto m2 = vmul_f16(vget_high_f16(qx_scales_2), q8_scales_h);
+ acc[iy ] = vaddq_f32(acc[iy ], vcvt_f32_f16(m1));
+ acc[iy+nrc_y ] = vaddq_f32(acc[iy+nrc_y], vcvt_f32_f16(m2));
+ sc16[iy ] = vmul_f16(vget_low_f16(qx_scales_1), q8_scales_l);
+ sc16[iy+nrc_y] = vmul_f16(vget_low_f16(qx_scales_2), q8_scales_l);
+ }
+ }
+
+ template <typename Dequantizer>
inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
deq.prepare1(i);
float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);
@@ -5236,6 +5284,17 @@ inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc1
}
template <typename Dequantizer, typename Q8>
+inline void sum_4(int i, Dequantizer& deq1, Dequantizer& deq2, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto pall = sum_4_blocks(deq1.bits.b, deq2.bits.b, q8.quant_data(iy, i));
+ auto scale1 = vcvt_f32_f16(sc16[iy]);
+ auto scale2 = vcvt_f32_f16(sc16[iy+Q8::nrc_y]);
+ acc[iy] = vmlaq_f32(acc[iy], scale1, vcvtq_f32_s32(pall.val[0]));
+ acc[iy+Q8::nrc_y] = vmlaq_f32(acc[iy+Q8::nrc_y], scale2, vcvtq_f32_s32(pall.val[1]));
+ }
+}
+
+template <typename Dequantizer, typename Q8>
inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {
const int nb = n / QK4_1;
@@ -5263,6 +5322,35 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i
}
template <typename Dequantizer, typename Q8>
+inline void mul_mat_qX_Y_q8_Y_IK(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK4_1;
+
+ float16x4_t sc16[2*Q8::nrc_y];
+ float32x4_t acc[2*Q8::nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ix += 2) {
+
+ deq1.new_row(ix+0);
+ deq2.new_row(ix+1);
+
+ for (int iy = 0; iy < 2*Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+ for (int i = 0; i < nb/4; ++i) {
+ q8.process_scales(i, deq1, deq2, sc16, acc);
+ sum_4(i, deq1, deq2, q8, sc16, acc);
+ }
+ //for (int i = 4*(nb/4); i < nb; ++i) {
+ // q8.process_1_block(i, deq, acc);
+ //}
+
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ info.store(ix+0, iy, vaddvq_f32(acc[iy]));
+ info.store(ix+1, iy, vaddvq_f32(acc[iy+Q8::nrc_y]));
+ }
+ }
+}
+
+template <typename Dequantizer, typename Q8>
inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
const int nb = n / QK4_1;
@@ -5300,8 +5388,15 @@ static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo&
Dequantizer deq1(vx, bx), deq2(vx, bx);
mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
} else {
- Dequantizer deq(vx, bx);
- mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ if (nrc_x%2 == 0) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ Dequantizer deq(vx, bx);
+ mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ }
+ //Dequantizer deq(vx, bx);
+ //mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
}
}
@@ -5312,8 +5407,13 @@ static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo&
Dequantizer deq1(vx, bx), deq2(vx, bx);
mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
} else {
- Dequantizer deq(vx, bx);
- mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ if (nrc_x%2 == 0) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ Dequantizer deq(vx, bx);
+ mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ }
}
}