diff options
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 123 |
1 files changed, 119 insertions, 4 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 13e6420b..511eea01 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -4941,6 +4941,29 @@ inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
     return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
 }
 
+// Two-row variant of sum_4_blocks: dots the same 128 int8 values at qs with the 8 quant vectors
+// in b1 and with the 8 in b2. val[0] is the result for b1, val[1] the one for b2; lane k of each
+// is the sum over the vector pair (b[2k], b[2k+1]) against qs[32k .. 32k+31].
+inline int32x4x2_t sum_4_blocks(const int8x16_t * b1, const int8x16_t * b2, const int8_t * qs) {
+    auto q8b = vld1q_s8_x2(qs + 0);
+    auto p12_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q8b.val[0]), b1[1], q8b.val[1]);
+    auto p12_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q8b.val[0]), b2[1], q8b.val[1]);
+    q8b = vld1q_s8_x2(qs + 32);
+    auto p34_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q8b.val[0]), b1[3], q8b.val[1]);
+    auto p34_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q8b.val[0]), b2[3], q8b.val[1]);
+    auto p1234_1 = vpaddq_s32(p12_1, p34_1);
+    auto p1234_2 = vpaddq_s32(p12_2, p34_2);
+    q8b = vld1q_s8_x2(qs + 64);
+    auto p56_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[4], q8b.val[0]), b1[5], q8b.val[1]);
+    auto p56_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[4], q8b.val[0]), b2[5], q8b.val[1]);
+    q8b = vld1q_s8_x2(qs + 96);
+    auto p78_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[6], q8b.val[0]), b1[7], q8b.val[1]);
+    auto p78_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[6], q8b.val[0]), b2[7], q8b.val[1]);
+    auto p5678_1 = vpaddq_s32(p56_1, p78_1);
+    auto p5678_2 = vpaddq_s32(p56_2, p78_2);
+    return { vpaddq_s32(p1234_1, p5678_1), vpaddq_s32(p1234_2, p5678_2)};
+}
+
 template <int nrc> struct Q80 {
 
     constexpr static int nrc_y = nrc;
@@ -4969,6 +4992,19 @@ template <int nrc> struct Q80 {
     }
 
     template <typename Dequantizer>
+    inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * /*acc*/) const {
+        // Two-row version: starts block i on both dequantizers and stores the products of their
+        // scales with the q8 scales of every y-row in sc16[iy] (deq1) and sc16[iy+nrc_y] (deq2).
+        auto qx_scales_1 = deq1.new_block(i);
+        auto qx_scales_2 = deq2.new_block(i);
+        for (int iy = 0; iy < nrc; ++iy) {
+            auto q8_scales = load_scales(iy, i);
+            sc16[iy      ] = vmul_f16(qx_scales_1, q8_scales);
+            sc16[iy+nrc_y] = vmul_f16(qx_scales_2, q8_scales);
+        }
+    }
+
+    template <typename Dequantizer>
     inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
         deq.prepare1(i);
         float d = GGML_FP16_TO_FP32(deq.x[i].d);
@@ -5012,6 +5048,26 @@ template <int nrc> struct Q81 {
     }
 
     template <typename Dequantizer>
+    inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * acc) const {
+        // Two-row version. The products of the high halves of the dequantizer and q8 scale vectors
+        // (presumably the block-minimum contribution - confirm against the dequantizers) are added
+        // straight into acc; the low-half products go to sc16[iy] (deq1) and sc16[iy+nrc_y] (deq2).
+        auto qx_scales_1 = deq1.new_block(i);
+        auto qx_scales_2 = deq2.new_block(i);
+        for (int iy = 0; iy < nrc; ++iy) {
+            auto q8_scales = load_scales(iy, i);
+            auto q8_scales_l = vget_low_f16(q8_scales);
+            auto q8_scales_h = vget_high_f16(q8_scales);
+            auto m1 = vmul_f16(vget_high_f16(qx_scales_1), q8_scales_h);
+            auto m2 = vmul_f16(vget_high_f16(qx_scales_2), q8_scales_h);
+            acc[iy      ] = vaddq_f32(acc[iy      ], vcvt_f32_f16(m1));
+            acc[iy+nrc_y] = vaddq_f32(acc[iy+nrc_y], vcvt_f32_f16(m2));
+            sc16[iy      ] = vmul_f16(vget_low_f16(qx_scales_1), q8_scales_l);
+            sc16[iy+nrc_y] = vmul_f16(vget_low_f16(qx_scales_2), q8_scales_l);
+        }
+    }
+
+    template <typename Dequantizer>
     inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
         deq.prepare1(i);
         float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);
@@ -5236,6 +5292,19 @@ inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc1
 }
 
 template <typename Dequantizer, typename Q8>
+inline void sum_4(int i, Dequantizer& deq1, Dequantizer& deq2, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
+    // Two-row version: for every y-row, scales the integer sums of both x-rows by the scales
+    // prepared in sc16 and accumulates them into acc[iy] (deq1) and acc[iy+Q8::nrc_y] (deq2).
+    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+        auto pall = sum_4_blocks(deq1.bits.b, deq2.bits.b, q8.quant_data(iy, i));
+        auto scale1 = vcvt_f32_f16(sc16[iy]);
+        auto scale2 = vcvt_f32_f16(sc16[iy+Q8::nrc_y]);
+        acc[iy] = vmlaq_f32(acc[iy], scale1, vcvtq_f32_s32(pall.val[0]));
+        acc[iy+Q8::nrc_y] = vmlaq_f32(acc[iy+Q8::nrc_y], scale2, vcvtq_f32_s32(pall.val[1]));
+    }
+}
+
+template <typename Dequantizer, typename Q8>
 inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {
     const int nb = n / QK4_1;
 
@@ -5263,6 +5332,40 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i
 }
 
 template <typename Dequantizer, typename Q8>
+inline void mul_mat_qX_Y_q8_Y_IK(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
+    // Processes two x-rows per iteration (deq1 -> row ix, deq2 -> row ix+1) so that each q8 load
+    // is shared between them. Callers must pass an even nrc_x.
+    const int nb = n / QK4_1;
+
+    float16x4_t sc16[2*Q8::nrc_y];
+    float32x4_t acc[2*Q8::nrc_y];
+
+    for (int ix = 0; ix < nrc_x; ix += 2) {
+
+        deq1.new_row(ix+0);
+        deq2.new_row(ix+1);
+
+        for (int iy = 0; iy < 2*Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+        for (int i = 0; i < nb/4; ++i) {
+            q8.process_scales(i, deq1, deq2, sc16, acc);
+            sum_4(i, deq1, deq2, q8, sc16, acc);
+        }
+        // Blocks left over when nb is not a multiple of 4 (same tail handling as
+        // mul_mat_qX_Y_q8_Y); row ix+1 accumulates into the upper half of acc.
+        for (int i = 4*(nb/4); i < nb; ++i) {
+            q8.process_1_block(i, deq1, acc);
+            q8.process_1_block(i, deq2, acc + Q8::nrc_y);
+        }
+
+        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+            info.store(ix+0, iy, vaddvq_f32(acc[iy]));
+            info.store(ix+1, iy, vaddvq_f32(acc[iy+Q8::nrc_y]));
+        }
+    }
+}
+
+template <typename Dequantizer, typename Q8>
 inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
     const int nb = n / QK4_1;
 
@@ -5300,8 +5403,14 @@ static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo&
         Dequantizer deq1(vx, bx), deq2(vx, bx);
         mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
     } else {
-        Dequantizer deq(vx, bx);
-        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+        // The two-rows-at-a-time kernel needs an even number of x-rows.
+        if (nrc_x%2 == 0) {
+            Dequantizer deq1(vx, bx), deq2(vx, bx);
+            mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+        } else {
+            Dequantizer deq(vx, bx);
+            mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+        }
     }
 }
 
@@ -5312,8 +5421,14 @@ static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo&
         Dequantizer deq1(vx, bx), deq2(vx, bx);
         mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
     } else {
-        Dequantizer deq(vx, bx);
-        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+        // The two-rows-at-a-time kernel needs an even number of x-rows.
+        if (nrc_x%2 == 0) {
+            Dequantizer deq1(vx, bx), deq2(vx, bx);
+            mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+        } else {
+            Dequantizer deq(vx, bx);
+            mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+        }
     }
 }
 