author     Kawrakow <iwankawrakow@gmail.com>  2025-06-21 16:35:08 +0200
committer  GitHub <noreply@github.com>        2025-06-21 16:35:08 +0200
commit     4f97409b80dffa96abe1a31d0a06e6dde78e91b7 (patch)
tree       533bcb5cc7cc0bccf307317ae5b8d37403fd19b5
parent     a98b7678a305c560117ce0a63a3529f2aaa17acb (diff)
Faster ARM_NEON GEMM implementation for legacy quants (#546)
* iq2_kt and iq3_kt work with the new int trellis
  Much slower than the fp16-based trellis. I guess Apple doesn't have
  int8_t SIMD on the M2-Max GPU. (A scalar sketch of the new generator
  follows the commit message.)
* q4_0
  83.6 t/s -> 128.4 t/s. q4_0_r8 is at 123.5 t/s.
* q5_0
  74.2 t/s -> 128.5 t/s. q5_0_r4 is at 111.4 t/s.
* q6_0
  74.2 t/s -> 128.8 t/s. q6_0_r4 is at 107.2 t/s.
* q8_0
  84.5 t/s -> 128.7 t/s. q8_0_r8 is at 131 t/s.
* iq4_nl
  84.5 t/s -> 128.1 t/s. iq4_nl_r4 is at 120.4 t/s.
* q4_1
  74.4 t/s -> 115.4 t/s. There is no repacked variant.
* q5_1
  64.2 t/s -> 114.9 t/s. There is no repacked variant.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
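
For readers of the Metal part of the diff below: the new int trellis drops the additive constants (kb, kb1, ...) of the old affine generator and keeps only a multiplier. A minimal scalar sketch in plain C++ (illustration only, modeled on Trellis3::gen4 from the diff; trellis3_next is a hypothetical name, not code from the repository):

#include <cstdint>

// One trellis step: multiply the 32-bit state by ka, keep the low 6 bits of
// each byte, and sum the four bytes; subtracting 126 centers the result, so
// the output is an int8_t in -126..126.
static int8_t trellis3_next(uint32_t& state) {
    constexpr uint32_t ka    = 0xCBAC1FED; // new multiplier; the kb offsets are gone
    constexpr uint32_t kmask = 0x3f3f3f3f; // 6 bits per byte
    state *= ka;                           // multiplicative congruential update
    const uint32_t x = state & kmask;      // four bytes, each in 0..63
    const int sum = (x & 0xff) + ((x >> 8) & 0xff) + ((x >> 16) & 0xff) + ((x >> 24) & 0xff);
    return (int8_t)(sum - 126);            // sum of four bytes is 0..252
}

Calling this four times from the same seed reproduces the values gen4 derives from ka*val, ka1*val, ka2*val, ka3*val, since ka1 = ka*ka and so on.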
 ggml/src/ggml-metal.metal               |  29
 ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 249
 ggml/src/iqk/iqk_mul_mat.cpp            |  14
 src/llama.cpp                           |   2
 4 files changed, 260 insertions(+), 34 deletions(-)
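
The centerpiece of the CPU-side change is a new 8-row interleaved block for the type-1 legacy quants (those that carry an offset next to the scale). Unlike block_q8_0_r8, each block needs a per-row offset in addition to the per-row scale, which is why q4_1 and q5_1 are repacked to GGML_TYPE_Q8_1 rather than GGML_TYPE_Q8_0_R8. The struct from the diff below, with annotations added here (the layout reading is inferred from the converter, which stores x.d at d[k] and x.m at d[k+8]):

typedef struct {
    ggml_half d[16];   // d[0..7]: per-row scales (d), d[8..15]: per-row offsets (m)
    int8_t    qs[256]; // 8 rows x 32 int8 quants, interleaved in 4-byte chunks
} block_q8_1_r8;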
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index c3c4f0bb..e3bd070d 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -6598,16 +6598,12 @@ void kernel_mul_mv_iq2_k_f32_impl(
 
 struct Trellis3 {
     constexpr constant static uint32_t kmask = 0x3f3f3f3f;
-    constexpr constant static uint32_t ka = 89226354;
-    constexpr constant static uint32_t kb = 64248484;
+    constexpr constant static uint32_t ka = 0xCBAC1FED;
     constexpr constant static uint32_t ka1 = ka*ka;
-    constexpr constant static uint32_t kb1 = kb*ka+kb;
     constexpr constant static uint32_t ka2 = ka1*ka;
-    constexpr constant static uint32_t kb2 = kb1*ka+kb;
     constexpr constant static uint32_t ka3 = ka2*ka;
-    constexpr constant static uint32_t kb3 = kb2*ka+kb;
     static inline char4 gen4(uint32_t val) {
-        thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask};
+        thread uint32_t aux[4] = {(ka*val) & kmask, (ka1*val) & kmask, (ka2*val) & kmask, (ka3*val) & kmask};
         thread const int8_t * a8 = (thread const int8_t *)aux;
         char4 result;
         for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3];
@@ -6615,14 +6611,18 @@ struct Trellis3 {
     }
     template <typename T4>
     static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
-        thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3};
+        thread uint32_t aux[4] = {ka*val, ka1*val, ka2*val, ka3*val};
         uint32_t aux32[2];
         thread const int8_t * a8 = (thread const int8_t *)aux32;
+        //thread const char4 * a8 = (thread const char4 *)aux32;
         for (int i = 0; i < 4; ++i) {
            aux32[0] = aux[i] & kmask;
-            aux32[1] = (ka3*aux[i] + kb3) & kmask;
+            aux32[1] = (ka3*aux[i]) & kmask;
             v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3];
             v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7];
+            // Much slower:
+            //v1[i] = -126 + a8[0][0] + a8[0][1] + a8[0][2] + a8[0][3];
+            //v2[i] = -126 + a8[1][0] + a8[1][1] + a8[1][2] + a8[1][3];
         }
     }
 };
@@ -6837,7 +6837,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
     float drow[N_DST];
     for (int row = 0; row < N_DST; ++row) {
         device const float * dptr = (device const float *)(cx + row*row_size);
-        drow[row] = dptr[0] * 31.75f * 1.01f;
+        drow[row] = dptr[0] * 1.01f;
     }
 
     device const block_iq3_kt * x = (device const block_iq3_kt *)(cx + sizeof(float));
@@ -6854,7 +6854,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
             const float ls = drow[row] * ((sc[(it/2)%4] >> 4*(it/8)) & 0xf);
             const uint8_t mask = 1 << (it/2);
-            Trellis::gen8(q2[2*it+0]+4096, v[0], v[1]);
+            Trellis3::gen8(q2[2*it+0]+4096, v[0], v[1]);
             for (int j = 0; j < 8; ++j) {
                 u32[j] &= 0x7fffffff;
                 u32[j] |= qh[j+0] & mask ? 0x80000000 : 0;
             }
             auto sum = v[0]*y4[0] + v[1]*y4[1];
-            Trellis::gen8(q2[2*it+1]+4096, v[0], v[1]);
+            Trellis3::gen8(q2[2*it+1]+4096, v[0], v[1]);
             for (int j = 0; j < 8; ++j) {
                 u32[j] &= 0x7fffffff;
                 u32[j] |= qh[j+8] & mask ? 0x80000000 : 0;
@@ -8593,17 +8593,14 @@ template <typename type4x4>
 void dequantize_iq3_kt(device const block_iq3_kt * x, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256
     int ib32 = il/2;
-    half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 31.75h * 1.01h;
+    half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 1.01h;
     device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
     device const uint8_t * qh = x->qh + 16*(il%2);
     const uint8_t mask = 1 << ib32;
     half4 v1, v2;
     for (int i = 0; i < 2; ++i) {
-        Trellis::gen8(q2[i]+4096, v1, v2);
-        //v1 *= scale; v2 *= scale;
-        //for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -abs(v1[j]) : abs(v1[j]);
-        //for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -abs(v2[j]) : abs(v2[j]);
+        Trellis3::gen8(q2[i]+4096, v1, v2);
         v1 = abs(v1)*scale; v2 = abs(v2)*scale;
         for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -v1[j] : v1[j];
         for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -v2[j] : v2[j];
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 312c556e..ab6eb130 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -2782,21 +2782,239 @@ void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+typedef struct {
+    ggml_half d[16];
+    int8_t    qs[256];
+} block_q8_1_r8;
+
+template <int nrc_y>
+void mul_mat_q8_1_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_1_x4> q8(info);
+    int nb = n / QK8_0;
+    float32x4_t acc[2*nrc_y] = {};
+    int8x16_t qx[16];
+    float d8[8*nrc_y];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+0)));
+                vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+4)));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+                auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+                auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+                auto m16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d+8);
+                auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+                auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+                for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+                int32x4_t sumi1, sumi2;
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
+                    auto dy = vdupq_n_f32(d8[8*iy+k]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                    auto my = vdupq_n_f32(d8[8*iy+k+4]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+            auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+            auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+            auto m16 = vld1q_f16((const float16_t *)iq8[ib].d+8);
+            auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+            auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+            for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+            int32x4_t sumi1, sumi2;
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                auto my = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].s));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix+0, iy, acc[2*iy+0]);
+            info.store(ix+4, iy, acc[2*iy+1]);
+            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+        }
+    }
 }
 
-bool iqk_convert_legacy_quants_q8_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
-    return false;
-    //switch (type) {
-    //    case GGML_TYPE_Q4_0  : iqk_convert_qX_q80_r8<block_q4_0, Q4_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q4_1  : iqk_convert_qX_1_q8_1_r8<block_q4_1, Q4_1_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q5_0  : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q5_1  : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q6_0  : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q8_0  : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
-    //    default: return false;
-    //}
-    //return true;
+struct DeqQ40 {
+    const int8x16_t  m8 = vdupq_n_s8(-8);
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_q4_0& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vaddq_s8(vreinterpretq_s8_u8(vandq_u8(bits, ml)), m8), vaddq_s8(vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)), m8) };
+    }
+};
+
+struct DeqQ41 {
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_q4_1& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vreinterpretq_s8_u8(vandq_u8(bits, ml)), vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)) };
+    }
+};
+
+struct DeqIQ4NL {
+    const int8x16_t  mt = load_values();
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_iq4_nl& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vqtbl1q_s8(mt, vandq_u8(bits, ml)), vqtbl1q_s8(mt, vshrq_n_u8(bits, 4)) };
+    }
+    static inline int8x16_t load_values() { return vld1q_s8(iq4k_values); }
+};
+
+struct DeqQ50 {
+
+    inline int8x16x2_t dequant(const block_q5_0& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh = x.qh;
+        r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));
+        r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    HighBit5Legacy hbits;
+    const uint8x16_t mh = vdupq_n_u8(0xf0);
+};
+
+struct DeqQ51 {
+
+    inline int8x16x2_t dequant(const block_q5_1& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh = x.qh;
+        r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
+        r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    HighBit5Legacy hbits;
+    const uint8x16_t mh = vdupq_n_u8(0x10);
+};
+
+struct DeqQ60 {
+
+    inline int8x16x2_t dequant(const block_q6_0& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh8 = vld1_u8(x.qh);
+        auto qh  = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+        r.val[0] = vaddq_s8(vorrq_u8(r.val[0], vandq_u8(qh, hmask)), m32);
+        r.val[1] = vaddq_s8(vorrq_u8(r.val[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32);
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    const int8x16_t  m32   = vdupq_n_s8(-32);
+    const uint8x16_t hmask = vdupq_n_u8(0x30);
+};
+
+struct DeqQ80 {
+    inline int8x16x2_t dequant(const block_q8_0& x) const {
+        return vld1q_s8_x2(x.qs);
+    }
+};
+
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK8_0;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const Block * x8[8];
+
+    uint32_t block[8];
+
+    Dequantizer deq;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k] = x8[k][i].d;
+                vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k +  0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK8_0;
+
+    block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
+
+    const Block * x8[8];
+
+    uint32_t block[8];
+
+    Dequantizer deq;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k+0] = x8[k][i].d;
+                y[i].d[k+8] = x8[k][i].m;
+                vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k +  0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
+}
+
+bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    switch (type) {
+        case GGML_TYPE_Q4_0  : iqk_convert_qX_q80_r8<block_q4_0, DeqQ40>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q4_1  : iqk_convert_qX_1_q8_1_r8<block_q4_1, DeqQ41>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q5_0  : iqk_convert_qX_q80_r8<block_q5_0, DeqQ50>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q5_1  : iqk_convert_qX_1_q8_1_r8<block_q5_1, DeqQ51>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q6_0  : iqk_convert_qX_q80_r8<block_q6_0, DeqQ60>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, DeqIQ4NL>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q8_0  : iqk_convert_qX_q80_r8<block_q8_0, DeqQ80>(n, vx, bx, vy, nrc_x); break;
+        default: return false;
+    }
+    return true;
 }
 
 bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -2804,7 +3022,7 @@
     if (ne00%QK8_0 != 0) return false;
 
     auto etypeA = ggml_type(typeA);
-    auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
+    auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 || etypeA == GGML_TYPE_Q8_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
     if (ggml_type(typeB) != expected_typeB) return false;
 
     func16 = nullptr;
@@ -2843,6 +3061,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
         case GGML_TYPE_Q8_0_R8:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_0, kernels);
             break;
+        case GGML_TYPE_Q8_1:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_1, kernels);
+            break;
         case GGML_TYPE_IQ4_NL_R4:
             IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
             break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 41c4f980..ce0753a5 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -271,9 +271,16 @@ struct MulMat {
         }
 #else
         switch (type) {
-            case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-            case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-            case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q4_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q4_1   : return nrc_y >= 32 ? GGML_TYPE_Q8_1    : type;
+            case GGML_TYPE_Q5_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q5_1   : return nrc_y >= 32 ? GGML_TYPE_Q8_1    : type;
+            case GGML_TYPE_Q6_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q8_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
             default: break;
         }
 #endif
@@ -913,6 +920,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_Q5_0_R4:
         case GGML_TYPE_Q6_0_R4:
         case GGML_TYPE_Q8_0_R8:
+        case GGML_TYPE_Q8_1:
         case GGML_TYPE_IQ4_NL_R4:
             return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ1_BN:
diff --git a/src/llama.cpp b/src/llama.cpp
index c0f147b9..a70d2582 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18722,7 +18722,7 @@ static std::pair<ggml_type, int> interleaved_properties(ggml_type type) {
         { GGML_TYPE_IQ5_KS_R4, { GGML_TYPE_IQ5_KS,  4} },
         { GGML_TYPE_IQ5_K_R4,  { GGML_TYPE_IQ5_K,   4} },
         { GGML_TYPE_Q8_KV_R8,  { GGML_TYPE_Q8_KV,   8} },
-        { GGML_TYPE_Q8_K_R8,   { GGML_TYPE_Q8_K,    8} },
+        { GGML_TYPE_Q8_K_R8,   { GGML_TYPE_Q8_0,    8} },
         { GGML_TYPE_BF16_R16,  { GGML_TYPE_BF16,   16} },
     };
     if (auto it = k_map.find(type); it != k_map.end()) return it->second;
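
For reference, the 4-byte chunk interleaving that both converter templates above perform can be modeled in scalar code as follows (a standalone sketch with hypothetical names, not code from the repository):

#include <cstdint>
#include <cstring>

// Repack one block group: 8 rows x 32 int8 quants are regrouped in 4-byte
// chunks so that chunk l of rows 0..7 is stored contiguously. Row k's first
// 16 bytes land in out[0..127], its last 16 bytes in out[128..255], matching
// qs[8*l + k] and qs[8*l + k + 32] in the converters above.
void interleave_r8_block(const int8_t rows[8][32], int8_t out[256]) {
    uint32_t block[8];
    uint32_t qs[64];
    for (int k = 0; k < 8; ++k) {          // k: source row within the group of 8
        std::memcpy(block, rows[k], 32);   // one row's 32 quants as 8 uint32 chunks
        for (int l = 0; l < 4; ++l) {
            qs[8*l + k +  0] = block[l + 0];
            qs[8*l + k + 32] = block[l + 4];
        }
    }
    std::memcpy(out, qs, sizeof(qs));
}

With this layout, each 16-byte vector load in the GEMM (vld1q_s8(qs + 16*j)) picks up the same 4-byte chunk from four consecutive rows, which is what lets one set of loads feed the accumulators for all eight rows.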