Diffstat (limited to 'ggml/src/iqk/iqk_gemm_legacy_quants.cpp')
-rw-r--r-- ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 249
1 file changed, 235 insertions(+), 14 deletions(-)
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 312c556e..ab6eb130 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -2782,21 +2782,239 @@ void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
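+// Eight interleaved rows of q8_1: d[0..7] hold the per-row scales,
+// d[8..15] the per-row affine offsets (m), and qs packs 8 rows x 32 quants.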
+typedef struct {
+ ggml_half d[16];
+ int8_t qs[256];
+} block_q8_1_r8;
+
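+// GEMM: 8 interleaved q8_1 rows (vx) times nrc_y q8_1_x4 activation rows.
+// Per block, the integer dot product is scaled by d_x*d_y, and the affine
+// part is added as m_x * s_y, s_y being the per-block sum stored after the
+// four deltas in block_q8_1_x4.d.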
+template <int nrc_y>
+void mul_mat_q8_1_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_1_x4> q8(info);
+ int nb = n / QK8_0;
+ float32x4_t acc[2*nrc_y] = {};
+ int8x16_t qx[16];
+ float d8[8*nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+0)));
+ vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+4)));
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+ auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+ auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+ auto m16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d+8);
+ auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+ auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+ for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+ int32x4_t sumi1, sumi2;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
+ auto dy = vdupq_n_f32(d8[8*iy+k]);
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+ auto my = vdupq_n_f32(d8[8*iy+k+4]);
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+ }
+ }
+ }
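+        // Tail: left-over blocks when nb is not a multiple of 4; the
+        // activations are indexed as plain block_q8_1 here.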
+ for (int ib = 4*(nb/4); ib < nb; ++ib) {
+ auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+ auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+ auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+ auto m16 = vld1q_f16((const float16_t *)iq8[ib].d+8);
+ auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+ auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+ for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+ int32x4_t sumi1, sumi2;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto qy = (const block_q8_1 *)q8.y[iy];
+ qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+ auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+ auto my = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].s));
+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix+0, iy, acc[2*iy+0]);
+ info.store(ix+4, iy, acc[2*iy+1]);
+ acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ }
+ }
}
-bool iqk_convert_legacy_quants_q8_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
- return false;
- //switch (type) {
- // case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, Q4_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
- // case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, Q4_1_Dequantizer>(n, vx, bx, vy, nrc_x); break;
- // case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
- // case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break;
- // case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
- // case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
- // case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
- // default: return false;
- //}
- //return true;
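+// Dequantizers: each expands one legacy-quant block into 32 int8 values
+// (two 16-byte vectors) for repacking into the interleaved q8 formats below.
+
+// q4_0: low and high nibbles with a fixed -8 offset.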
+struct DeqQ40 {
+ const int8x16_t m8 = vdupq_n_s8(-8);
+    const uint8x16_t ml = vdupq_n_u8(0xf);
+ inline int8x16x2_t dequant(const block_q4_0& x) const {
+ auto bits = vld1q_u8(x.qs);
+ return { vaddq_s8(vreinterpretq_s8_u8(vandq_u8(bits, ml)), m8), vaddq_s8(vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)), m8) };
+ }
+};
+
+struct DeqQ41 {
+    const uint8x16_t ml = vdupq_n_u8(0xf);
+ inline int8x16x2_t dequant(const block_q4_1& x) const {
+ auto bits = vld1q_u8(x.qs);
+ return { vreinterpretq_s8_u8(vandq_u8(bits, ml)), vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)) };
+ }
+};
+
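+// iq4_nl: each nibble indexes a 16-entry non-linear value table.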
+struct DeqIQ4NL {
+ const int8x16_t mt = load_values();
+    const uint8x16_t ml = vdupq_n_u8(0xf);
+ inline int8x16x2_t dequant(const block_iq4_nl& x) const {
+ auto bits = vld1q_u8(x.qs);
+ return { vqtbl1q_s8(mt, vandq_u8(bits, ml)), vqtbl1q_s8(mt, vshrq_n_u8(bits, 4)) };
+ }
+ static inline int8x16_t load_values() { return vld1q_s8(iq4k_values); }
+};
+
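+// q5_0: the 5th bit comes from qh; or-ing 0xf0 into bytes whose high bit is
+// clear maps nibble n to n-16 in two's complement.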
+struct DeqQ50 {
+
+ inline int8x16x2_t dequant(const block_q5_0& x) const {
+ int8x16x2_t r;
+ bits.prepare1(x.qs, r.val);
+ auto qh = x.qh;
+ r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));
+ r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));
+ return r;
+ }
+
+ Q4LegacyBits bits;
+ HighBit5Legacy hbits;
+ const uint8x16_t mh = vdupq_n_u8(0xf0);
+};
+
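+// q5_1: unsigned variant; the 5th bit from qh is or-ed in at bit 4.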
+struct DeqQ51 {
+
+ inline int8x16x2_t dequant(const block_q5_1& x) const {
+ int8x16x2_t r;
+ bits.prepare1(x.qs, r.val);
+ auto qh = x.qh;
+ r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
+ r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
+ return r;
+ }
+
+ Q4LegacyBits bits;
+ HighBit5Legacy hbits;
+ const uint8x16_t mh = vdupq_n_u8(0x10);
+};
+
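+// q6_0: each qh byte carries four 2-bit high parts; shifting by 4 and 2 and
+// masking with 0x30 aligns them above the low nibble, then -32 recenters.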
+struct DeqQ60 {
+
+ inline int8x16x2_t dequant(const block_q6_0& x) const {
+ int8x16x2_t r;
+ bits.prepare1(x.qs, r.val);
+ auto qh8 = vld1_u8(x.qh);
+ auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+        r.val[0] = vaddq_s8(vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(qh, hmask))), m32);
+        r.val[1] = vaddq_s8(vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(vshrq_n_u8(qh, 2), hmask))), m32);
+ return r;
+ }
+
+ Q4LegacyBits bits;
+ const int8x16_t m32 = vdupq_n_s8(-32);
+ const uint8x16_t hmask = vdupq_n_u8(0x30);
+};
+
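+// q8_0: already int8, a plain 32-byte load.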
+struct DeqQ80 {
+ inline int8x16x2_t dequant(const block_q8_0& x) const {
+ return vld1q_s8_x2(x.qs);
+ }
+};
+
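+// Repack 8 rows into q8_0_r8: each 32-value block is dequantized to int8 and
+// interleaved in 4-byte units, so 4 bytes of row k land at uint32 slots
+// qs[8*l + k] (low half of the block) and qs[8*l + k + 32] (high half).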
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK4_0 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ const int nb = n/QK8_0;
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const Block * x8[8];
+
+ uint32_t block[8];
+
+ Dequantizer deq;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+
+ for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ y[i].d[k] = x8[k][i].d;
+ vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+ auto qs = (uint32_t *)y[i].qs;
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ y += nb;
+ }
+}
+
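+// Same repacking for the affine (_1) variants: the per-row offset m goes
+// into the second half of d[] (d[k+8]).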
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK4_0 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ const int nb = n/QK8_0;
+
+ block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
+
+ const Block * x8[8];
+
+ uint32_t block[8];
+
+ Dequantizer deq;
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+
+ for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ y[i].d[k+0] = x8[k][i].d;
+ y[i].d[k+8] = x8[k][i].m;
+ vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+ auto qs = (uint32_t *)y[i].qs;
+ for (int l = 0; l < 4; ++l) {
+ qs[8*l + k + 0] = block[l + 0];
+ qs[8*l + k + 32] = block[l + 4];
+ }
+ }
+ }
+ y += nb;
+ }
+}
+
+}
+
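+// Dispatch: repack any supported legacy quant type to q8_0_r8 / q8_1_r8.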
+bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ switch (type) {
+ case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, DeqQ40>(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, DeqQ41>(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, DeqQ50>(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, DeqQ51>(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, DeqQ60>(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, DeqIQ4NL>(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_Q8_0 : iqk_convert_qX_q80_r8<block_q8_0, DeqQ80>(n, vx, bx, vy, nrc_x); break;
+ default: return false;
+ }
+ return true;
}
bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -2804,7 +3022,7 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
if (ne00%QK8_0 != 0) return false;
auto etypeA = ggml_type(typeA);
- auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
+ auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 || etypeA == GGML_TYPE_Q8_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
if (ggml_type(typeB) != expected_typeB) return false;
func16 = nullptr;
@@ -2843,6 +3061,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
case GGML_TYPE_Q8_0_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_0, kernels);
break;
+ case GGML_TYPE_Q8_1:
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_1, kernels);
+ break;
case GGML_TYPE_IQ4_NL_R4:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
break;