summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp166
1 files changed, 156 insertions, 10 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d16f01d9..0c1c1625 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3228,6 +3228,16 @@ struct Q5_1_Dequantizer {
return _mm256_or_si256(b4.dequant(x->qs), vqh);
}
};
+struct Q6_1_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i mh = _mm256_set1_epi8(0x30);
+ inline __m256i dequant(const block_q6_0 * x) const {
+ uint64_t aux64; std::memcpy(&aux64, x->qh, 8);
+ auto h128 = _mm_set_epi64x(aux64, aux64 << 4);
+ auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 2), h128);
+ return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(h256, mh));
+ }
+};
template <typename Q, typename Scales, typename Dequantizer>
struct Q_Unpacker {
@@ -3332,6 +3342,11 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_
using Sum4T = Sum4Type1;
inline static int block_size() { return QK4_1; }
};
+struct Q6_0_1_Unpacker final : public Q_Unpacker<block_q6_0, ScaleHelperQ_0_1<32>, Q6_1_Dequantizer> {
+ Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ81;
+ inline static int block_size() { return QK5_0; }
+};
// float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but only to f32 result
@@ -3628,7 +3643,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker> ||
std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
- std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
+ std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker> ||
+ std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;
@@ -3893,8 +3909,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
- //MulMat::set_functions<Q4_0_Unpacker>(mm);
- //expected_typeB = GGML_TYPE_Q8_0;
MulMat::set_functions<Q4_0_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
@@ -3905,8 +3919,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_Q5_0:
assert (ne00 % QK5_0 == 0);
- //MulMat::set_functions<Q5_0_Unpacker>(mm);
- //expected_typeB = GGML_TYPE_Q8_0;
MulMat::set_functions<Q5_0_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
@@ -3915,10 +3927,13 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
MulMat::set_functions<Q5_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
+ case GGML_TYPE_Q6_0:
+ assert (ne00 % QK6_0 == 0);
+ MulMat::set_functions<Q6_0_1_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_1;
+ break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
- //MulMat::set_functions<Q8_0_Unpacker>(mm);
- //expected_typeB = GGML_TYPE_Q8_0;
MulMat::set_functions<Q8_0_1_Unpacker>(mm);
expected_typeB = GGML_TYPE_Q8_1;
break;
@@ -5417,6 +5432,34 @@ struct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {
//ggml_half aux[4];
};
+struct DequantizerQ60 final : public BaseLegacyDequantizer<block_q6_0> {
+
+ DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ auto qh8 = vld1_u8(x[i].qh);
+ auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+ q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32);
+ q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+ const int8x16_t m32 = vdupq_n_s8(-32);
+ const uint8x16_t hmask = vdupq_n_u8(0x30);
+};
+
struct DequantizerIQ4NL final : public BaseLegacyDequantizer<block_iq4_nl> {
DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
@@ -6325,7 +6368,8 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
- std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL>) {
+ std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL> ||
+ std::is_same_v<Dequantizer, DequantizerQ60>) {
m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;
@@ -6492,6 +6536,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerQ51>(m);
expected_Btype = GGML_TYPE_Q8_1;
break;
+ case GGML_TYPE_Q6_0:
+ MulMat::set_functions<DequantizerQ60>(m);
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
case GGML_TYPE_Q8_0:
MulMat::set_functions<DequantizerQ80>(m);
expected_Btype = GGML_TYPE_Q8_0;
@@ -7227,6 +7275,64 @@ struct HelperIQ4nl final : public BaseHelper<step> {
#endif
};
+template <int D, int step>
+struct HelperQ60 final : public BaseHelper<step> {
+#ifdef __aarch64__
+ using block_q8 = block_q8_0;
+#else
+ using block_q8 = block_q8_1;
+#endif
+ using Base = BaseHelper<step>;
+ HelperQ60(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ int j = F16::block_size*i;
+ auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
+#ifdef __aarch64__
+ // TODO
+ auto vd = F16::set1(*(const float16_t *)&dl->d);
+ auto qh8 = vld1_u8(dl->qh);
+ auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+ auto qs = vld1q_u8(dl->qs);
+ qs = j%QK4_0 ? vshrq_n_u8(qs, 4) : vandq_u8(qs, mask_l);
+ qs = vorrq_u8(qs, vandq_u8(mask_h, j%QK4_0 ? vshrq_n_u8(qh, 2) : qh));
+ qs = vaddq_s8(qs, m32);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qs))));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qs))));
+#else
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
+ uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
+ auto bh = _mm_set_epi64x(aux64, aux64 << 4);
+#ifdef HAVE_FANCY_SIMD
+ auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
+ auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
+#else
+ if (j%QK4_0) {
+ bl = _mm_srli_epi16(bl, 4);
+ bh = _mm_srli_epi16(bh, 2);
+ }
+ auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32));
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
+#endif
+#endif
+ }
+
+#ifdef __AVX2__
+ const __m128i mask_l = _mm_set1_epi8(0x0f);
+ const __m128i mask_h = _mm_set1_epi8(0x30);
+ const __m128i m32 = _mm_set1_epi8(-32);
+#else
+ const uint8x16_t mask_l = vdupq_n_u8(0x0f);
+ const uint8x16_t mask_h = vdupq_n_u8(0x30);
+ const int8x16_t m32 = vdupq_n_s8(-32);
+#endif
+};
+
template <int q_step, int k_step>
struct FlashMS {
// Something goes wrong when storing and manipulating K*Q as fp16.
@@ -7759,6 +7865,14 @@ struct FlashQKfp32 {
mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+#ifdef __aarch64__
+ mul_mat_qX_0_q8_0<DequantizerQ60, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
+ mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
+ }
else {
GGML_ASSERT(false);
}
@@ -7880,6 +7994,28 @@ struct FlashQKfp32 {
#endif
}
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
+ switch (nq) {
+#ifdef __aarch64__
+ case 1: mul_mat_qX_0_q8_0<DequantizerQ60, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_0_q8_0<DequantizerQ60, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_0_q8_0<DequantizerQ60, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_0_q8_0<DequantizerQ60, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_0_q8_0<DequantizerQ60, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_0_q8_0<DequantizerQ60, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_0_q8_0<DequantizerQ60, 7>(D, kh.block, kh.stride, info, k_step); break;
+#else
+ case 1: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
+ case 2: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
+ case 3: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
+ case 4: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
+ case 5: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
+ case 6: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
+ case 7: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
+#endif
+ }
+ }
else {
GGML_ASSERT(false);
}
@@ -8019,7 +8155,8 @@ struct FlashAttn {
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
- std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
+ std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
+ std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else {
@@ -8364,6 +8501,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperIQ4nl<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
default: break;
}
}
@@ -8395,6 +8536,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperIQ4nl<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q6_0: {
+ HelperQ60<D, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
default: break;
}
@@ -8404,7 +8549,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true;
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
+ type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
return false;
}
}