diff options
Diffstat (limited to 'ggml')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_1bit.cpp | 150 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 2 |
2 files changed, 116 insertions, 36 deletions
diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp index ece76f3c..4ec4f44f 100644 --- a/ggml/src/iqk/iqk_gemm_1bit.cpp +++ b/ggml/src/iqk/iqk_gemm_1bit.cpp @@ -9,7 +9,6 @@ namespace { -#ifdef __AVX2__ static const uint64_t iq1s_grid_us[2048] = { 0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200, 0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000, @@ -524,8 +523,8 @@ static const uint64_t iq1s_grid_us[2048] = { 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101, 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202, }; -#else -static const uint32_t iq1s_grid_us[2048] = { +#ifdef __aarch64__ +static const uint32_t iq1s_grid_us_neon[2048] = { 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, @@ -2336,22 +2335,22 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1))); signs = vadd_s16(vdup_n_s16(-8), signs); auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4))); - qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); - qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); - qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]}); - qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); + qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); + qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); + qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]}); + qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask); qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask); qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask); @@ -2409,22 +2408,22 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4}; auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1))); signs = vaddq_s8(signs, vdupq_n_s8(-8)); - qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); - qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]}); - qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]}); - qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)], - iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]}); + qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); + qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]}); + qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]}); + qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)], + iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]}); auto shuffle = shuffle0; for (int j = 0; j < 4; ++j) { auto s = vqtbl1q_s8(signs, shuffle); @@ -2583,6 +2582,81 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, } } +template <int nrc_y> +void mul_mat_iq1_m_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + Q8<nrc_y, block_q8_K> q8(info); + int8x16x2_t qx[8]; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + uint8x16x2_t scale_shuffle = {vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0302030203020302}), + vreinterpretq_u8_u64(uint64x2_t{0x0504050405040504, 0x0706070607060706})}; + uint64x2x2_t delta_mask = {uint64x2_t{0x0008, 0x0080}, uint64x2_t{0x0800, 0x8000}}; + iq1m_scale_t block_scale; + for (int ix = 0; ix < nrc_x; ++ix) { + auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < n/QK_K; ++ibl) { + const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales; // 4 x uint16_t, each containing 4 scales + block_scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + float d = GGML_FP16_TO_FP32(block_scale.f16); + auto qs = iq1m[ibl].qs; + auto qh = iq1m[ibl].qh; + auto aux8 = vld1_u8(iq1m[ibl].scales); + auto aux16 = vcombine_u8(aux8, aux8); + uint16x8x2_t sc16 = { vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[0])), vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[1])) }; + sc16.val[0] = vmulq_u16(vandq_u16(sc16.val[0], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200)); + sc16.val[1] = vmulq_u16(vandq_u16(sc16.val[1], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200)); + sc16.val[0] = vaddq_u16(vshrq_n_u16(sc16.val[0], 8), vdupq_n_u16(1)); + sc16.val[1] = vaddq_u16(vshrq_n_u16(sc16.val[1], 8), vdupq_n_u16(1)); + scales.val[0] = vmovl_s16(vget_low_s16 (sc16.val[0])); + scales.val[1] = vmovl_s16(vget_high_s16(sc16.val[0])); + scales.val[2] = vmovl_s16(vget_low_s16 (sc16.val[1])); + scales.val[3] = vmovl_s16(vget_high_s16(sc16.val[1])); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + qx[2*ib64+0] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)], + iq1s_grid_us[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)]}), + vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], + iq1s_grid_us[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)]})}; + qx[2*ib64+1] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)], + iq1s_grid_us[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)]}), + vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], + iq1s_grid_us[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)]})}; + auto qh16 = (const uint16_t *)qh; + auto h1 = vdupq_n_u64(qh16[0]); + auto h2 = vdupq_n_u64(qh16[1]); + auto delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1))); + auto delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1))); + qx[2*ib64+0].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[0], 3), delta1); + qx[2*ib64+0].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[1], 3), delta2); + delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1))); + delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1))); + qx[2*ib64+1].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[0], 3), delta1); + qx[2*ib64+1].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[1], 3), delta2); + qs += 8; + qh += 4; + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = vdupq_n_s32(0); + for (int j = 0; j < 4; ++j) { + auto y1 = q8.load_quants(iy, ibl, 2*j+0); + auto dot1 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[0], y1.val[0]); + auto dot2 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[1], y1.val[1]); + auto y2 = q8.load_quants(iy, ibl, 2*j+1); + auto dot3 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[0], y2.val[0]); + auto dot4 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[1], y2.val[1]); + auto dot = vpaddq_s32(vpaddq_s32(dot1, dot2), vpaddq_s32(dot3, dot4)); + sumi = vmlaq_s32(sumi, dot, scales.val[j]); + } + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi)); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) { auto max_i16 = vdupq_n_u16(0); int16x8x4_t q[8]; @@ -2774,6 +2848,12 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, func16 = mul_mat_iq1_s_r4_q8_1<16>; expected_Btype = GGML_TYPE_Q8_K128; break; + case GGML_TYPE_IQ1_M: + if (ne00%QK_K != 0) return false; + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_q8_K, funcs); + func16 = mul_mat_iq1_m_q8_K<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; case GGML_TYPE_IQ1_M_R4: if (ne00%128 != 0) return false; IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 75ca87df..d7a5c1d8 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -279,7 +279,7 @@ struct MulMat { case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; - case GGML_TYPE_IQ1_M : return nrc_y >= 8 ? GGML_TYPE_Q8_K_R8 : type; + case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; |