Diffstat (limited to 'ggml')
 ggml/src/iqk/iqk_gemm_1bit.cpp | 150
 ggml/src/iqk/iqk_mul_mat.cpp   |   2
 2 files changed, 116 insertions(+), 36 deletions(-)
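
For reference, the new mul_mat_iq1_m_q8_K kernel added below is a NEON implementation of the standard iq1_m decode. Here is a scalar sketch of that decode for one super-block of 256 weights; it is not part of the patch, and assumes QK_K == 256, the block_iq1_m and iq1m_scale_t definitions from ggml-common.h, and the uint64_t iq1s_grid_us table that this patch makes available on aarch64:

// Scalar sketch (not from the patch) of the iq1_m decode the NEON kernel vectorizes.
// iq1s_grid_us bytes hold weight+1, i.e. values in {0,1,2}; byte order assumes
// little-endian, as these kernels do.
static void dequant_iq1_m_sketch(const block_iq1_m * x, float * y) {
    const uint16_t * sc = (const uint16_t *)x->scales; // 4 words: four 3-bit scales + 4 bits of the fp16 super-scale each
    iq1m_scale_t bs;
    bs.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = GGML_FP16_TO_FP32(bs.f16);
    for (int ib = 0; ib < 16; ++ib) {                  // 16 sub-blocks of 16 weights
        float dl = d * (2*((sc[ib/4] >> 3*(ib%4)) & 0x7) + 1);
        for (int j = 0; j < 2; ++j) {                  // two groups of 8 weights per sub-block
            uint8_t h = (x->qh[ib] >> 4*j) & 0xf;      // 3 high index bits + 1 delta-sign bit
            const uint8_t * g = (const uint8_t *)&iq1s_grid_us[x->qs[2*ib+j] | ((h & 7) << 8)];
            float delta = (h & 8) ? -0.125f : 0.125f;  // per-group shift, sign from the 4th qh bit
            for (int i = 0; i < 8; ++i) y[16*ib + 8*j + i] = dl * ((g[i] - 1) + delta);
        }
    }
}

To stay in int8 for the dot products, the kernel instead computes (2*s + 1) * (8*g - (8 -/+ 1)) per weight, which equals 8 * (2*s + 1) * ((g - 1) +/- 0.125); that is where the final 0.125f factor in info.store comes from.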
diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp
index ece76f3c..4ec4f44f 100644
--- a/ggml/src/iqk/iqk_gemm_1bit.cpp
+++ b/ggml/src/iqk/iqk_gemm_1bit.cpp
@@ -9,7 +9,6 @@
namespace {
-#ifdef __AVX2__
static const uint64_t iq1s_grid_us[2048] = {
0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,
0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000,
@@ -524,8 +523,8 @@ static const uint64_t iq1s_grid_us[2048] = {
0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101,
0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202,
};
-#else
-static const uint32_t iq1s_grid_us[2048] = {
+#ifdef __aarch64__
+static const uint32_t iq1s_grid_us_neon[2048] = {
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
@@ -2336,22 +2335,22 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
signs = vadd_s16(vdup_n_s16(-8), signs);
auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4)));
- qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
- qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
- qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
- qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
+ qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
+ qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
+ qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
+ qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask);
qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask);
qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask);
@@ -2409,22 +2408,22 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4};
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1)));
signs = vaddq_s8(signs, vdupq_n_s8(-8));
- qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
- qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
- qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
- qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
- iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
+ qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
+ qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
+ qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
+ qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
+ iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
auto shuffle = shuffle0;
for (int j = 0; j < 4; ++j) {
auto s = vqtbl1q_s8(signs, shuffle);
@@ -2583,6 +2582,81 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
+template <int nrc_y>
+void mul_mat_iq1_m_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ int8x16x2_t qx[8];
+ int32x4x4_t scales;
+ float32x4_t acc[nrc_y] = {};
+ uint8x16x2_t scale_shuffle = {vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0302030203020302}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0504050405040504, 0x0706070607060706})};
+ uint64x2x2_t delta_mask = {uint64x2_t{0x0008, 0x0080}, uint64x2_t{0x0800, 0x8000}};
+ iq1m_scale_t block_scale;
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+ const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales; // 4 x uint16_t, each containing 4 scales
+ block_scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ float d = GGML_FP16_TO_FP32(block_scale.f16);
+ auto qs = iq1m[ibl].qs;
+ auto qh = iq1m[ibl].qh;
+ auto aux8 = vld1_u8(iq1m[ibl].scales);
+ auto aux16 = vcombine_u8(aux8, aux8);
+ uint16x8x2_t sc16 = { vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[0])), vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[1])) };
+ sc16.val[0] = vmulq_u16(vandq_u16(sc16.val[0], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200));
+ sc16.val[1] = vmulq_u16(vandq_u16(sc16.val[1], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200));
+ sc16.val[0] = vaddq_u16(vshrq_n_u16(sc16.val[0], 8), vdupq_n_u16(1));
+ sc16.val[1] = vaddq_u16(vshrq_n_u16(sc16.val[1], 8), vdupq_n_u16(1));
+ scales.val[0] = vmovl_s16(vget_low_s16 (sc16.val[0]));
+ scales.val[1] = vmovl_s16(vget_high_s16(sc16.val[0]));
+ scales.val[2] = vmovl_s16(vget_low_s16 (sc16.val[1]));
+ scales.val[3] = vmovl_s16(vget_high_s16(sc16.val[1]));
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ qx[2*ib64+0] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)],
+ iq1s_grid_us[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)]}),
+ vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
+ iq1s_grid_us[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)]})};
+ qx[2*ib64+1] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)],
+ iq1s_grid_us[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)]}),
+ vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
+ iq1s_grid_us[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)]})};
+ auto qh16 = (const uint16_t *)qh;
+ auto h1 = vdupq_n_u64(qh16[0]);
+ auto h2 = vdupq_n_u64(qh16[1]);
+ auto delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1)));
+ auto delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1)));
+ qx[2*ib64+0].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[0], 3), delta1);
+ qx[2*ib64+0].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[1], 3), delta2);
+ delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1)));
+ delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1)));
+ qx[2*ib64+1].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[0], 3), delta1);
+ qx[2*ib64+1].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[1], 3), delta2);
+ qs += 8;
+ qh += 4;
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = vdupq_n_s32(0);
+ for (int j = 0; j < 4; ++j) {
+ auto y1 = q8.load_quants(iy, ibl, 2*j+0);
+ auto dot1 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[0], y1.val[0]);
+ auto dot2 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[1], y1.val[1]);
+ auto y2 = q8.load_quants(iy, ibl, 2*j+1);
+ auto dot3 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[0], y2.val[0]);
+ auto dot4 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[1], y2.val[1]);
+ auto dot = vpaddq_s32(vpaddq_s32(dot1, dot2), vpaddq_s32(dot3, dot4));
+ sumi = vmlaq_s32(sumi, dot, scales.val[j]);
+ }
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) {
auto max_i16 = vdupq_n_u16(0);
int16x8x4_t q[8];
@@ -2774,6 +2848,12 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
func16 = mul_mat_iq1_s_r4_q8_1<16>;
expected_Btype = GGML_TYPE_Q8_K128;
break;
+ case GGML_TYPE_IQ1_M:
+ if (ne00%QK_K != 0) return false;
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_q8_K, funcs);
+ func16 = mul_mat_iq1_m_q8_K<16>;
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_IQ1_M_R4:
if (ne00%128 != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);
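
A note on the scale unpacking in the new kernel (the vandq/vmulq/vshrq sequence on sc16 above): each 16-bit scale word packs four 3-bit scales at bit offsets 0, 3, 6 and 9, and one masked 16-bit multiply per lane aligns any of them to bits 9..11. A scalar illustration, using a hypothetical helper that is not code from the patch:

// Each 16-bit scale word holds four 3-bit scales at bit offsets 0, 3, 6, 9.
// Masking one field and multiplying by (0x0200 >> 3*k) moves it to bits 9..11
// (the product never exceeds 16 bits, so the low half of the NEON 16-bit
// multiply is exact); >> 8 then yields 2*s and the +1 gives the (2*s + 1)
// factor iq1_m uses.
static inline int unpack_iq1_m_scale(uint16_t word, int k /* 0..3 */) {
    uint16_t masked  = word & (uint16_t)(0x0007 << 3*k);               // isolate scale k
    uint16_t aligned = (uint16_t)(masked * (uint16_t)(0x0200 >> 3*k)); // field -> bits 9..11
    return (aligned >> 8) + 1;                                         // 2*s + 1
}

The NEON constants encode exactly this, lane by lane: masks 0x0007, 0x0038, 0x01c0, 0x0e00 (packed as 0x0e0001c000380007) and multipliers 0x0200, 0x0040, 0x0008, 0x0001 (packed as 0x0001000800400200).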
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 75ca87df..d7a5c1d8 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -279,7 +279,7 @@ struct MulMat {
case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
- case GGML_TYPE_IQ1_M : return nrc_y >= 8 ? GGML_TYPE_Q8_K_R8 : type;
+ case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
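
The iqk_mul_mat.cpp hunk raises the batch threshold at which IQ1_M weights get repacked to Q8_K_R8 from 8 to 32 columns, presumably because the new direct IQ1_M x Q8_K kernel now covers small and mid-size batches. An illustrative sketch of the effect (not the patch's code):

// Illustrative only: the effect of the changed case in MulMat's dispatch.
// nrc_y is the number of right-hand-side columns processed together.
ggml_type iq1_m_gemm_type(ggml_type type /* GGML_TYPE_IQ1_M */, int nrc_y) {
    return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8  // repack the weights for big batches (was nrc_y >= 8)
                       : type;              // keep IQ1_M and use the new mul_mat_iq1_m_q8_K kernel
}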