summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp487
1 files changed, 49 insertions, 438 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index b77d08b6..d7682e54 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -755,18 +755,6 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
};
-struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> {
- DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
- template <typename Q8>
- inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
- new_block(i);
- }
- inline void new_block(int i) {
- bits.prepare(x[i].qs);
- }
- Q2Bits bits;
-};
-
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
@@ -1256,7 +1244,7 @@ struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
Q4Bits bits;
Scales8KBase s8k;
const __m512i values;
- const __m512i mask15 = _mm512_set1_epi16(0xfffe);
+ const __m512i mask15 = _mm512_set1_epi16(-2); // value is 0xfffe, but to shut up the stupid compiler warning we use the signed value
const __m512i mask1 = _mm512_set1_epi16(1);
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
@@ -1319,22 +1307,13 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
deq.new_block(i, q8, accm, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
- auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
- auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
- _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
- deq.bits.values[0], q8.load_quants64(iy, i, 0)), deq.bits.values[1], q8.load_quants64(iy, i, 1)),
- deq.bits.values[2], q8.load_quants64(iy, i, 2)), deq.bits.values[3], q8.load_quants64(iy, i, 3));
- accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
- } else {
- const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
- const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
- const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
- const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
- auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
- sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
- accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
- }
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
}
@@ -1347,64 +1326,6 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
}
}
-template <int nrc_y>
-static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<nrc_y> q8(info);
-
- DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx);
-
- __m512 accd[2*nrc_y];
-
- for (int ix = 0; ix < nrc_x; ix += 2) {
-
- for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
-
- deq1.new_row(ix+0);
- deq2.new_row(ix+1);
-
- for (int i = 0; i < nb; ++i) {
-
- deq1.new_block(i);
- deq2.new_block(i);
- //float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per per tensor, so we can use the same scale for both rows
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto sumi_scales_256 = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
- auto sumi_scales_512 = _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales_256, 0);
- auto q8q = q8.load_quants64(iy, i, 0);
- auto sumi_1 = _mm512_dpbusd_epi32(sumi_scales_512, deq1.bits.values[0], q8q);
- auto sumi_2 = _mm512_dpbusd_epi32(sumi_scales_512, deq2.bits.values[0], q8q);
- q8q = q8.load_quants64(iy, i, 1);
- sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[1], q8q);
- sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[1], q8q);
- q8q = q8.load_quants64(iy, i, 2);
- sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[2], q8q);
- sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[2], q8q);
- q8q = q8.load_quants64(iy, i, 3);
- sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[3], q8q);
- sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
- // The scale is supposed to be per per tensor, so we can use the same scale
- auto vd = _mm512_set1_ps(/*d* */q8.scale(iy, i));
- accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
- accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
- // Leaving this here just in case ternary models start using per row scales
- //accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
- //accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
- }
-
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix+0, iy, deq1.d*_mm512_reduce_add_ps(accd[2*iy+0]));
- info.store(ix+1, iy, deq2.d*_mm512_reduce_add_ps(accd[2*iy+1]));
- }
-
- }
-}
-
template <typename Dequantizer, int nrc_y>
static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -1478,33 +1399,19 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
- for (int kx = 0; kx < k_nx; ++kx) {
- compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd);
- }
- } else {
- for (int kx = 0; kx < k_nx; ++kx) {
- compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
- }
+ for (int kx = 0; kx < k_nx; ++kx) {
+ compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
}
}
if (2*(nb/2) < nb) {
int i0 = 2*(nb/2);
deq[0]->new_block(i0, q8, &accm, scales);
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
- compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd);
- } else {
- compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
- }
+ compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
}
- if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
- info.store(ix, 0, _mm512_reduce_add_ps(accd));
- } else {
- auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
- info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
- }
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+ info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
}
}
@@ -2066,90 +1973,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
const __m256i mh = _mm256_set1_epi8(0x30);
};
-struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> {
- DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-
- inline void prepare(int i, int j) {
- bits.prepare(x[i].qs, j);
- }
-
- Q2Bits bits;
-};
-
-
-template <int nrc_y>
-IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n%QK_K == 0);
- const int nb = n/QK_K;
-
- Q8<nrc_y> q8(info);
- DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx);
-
- __m256 accd[nrc_y];
- const auto m1 = _mm256_set1_epi16(1);
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq1.new_row(ix);
- deq2.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- if constexpr (nrc_y == 1) {
- deq1.prepare(i, 0);
- auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(0, i, 0)),
- _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(0, i, 1)));
- sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(0, i, 2)),
- _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(0, i, 3))), sumi1);
-
- deq2.prepare(i, 1);
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(0, i, 4)),
- _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(0, i, 5)));
- sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(0, i, 6)),
- _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(0, i, 7))), sumi2);
- auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(0, i)));
- auto vd = _mm256_set1_ps(deq1.d*q8.scale(0, i));
- auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi));
- accd[0] = i > 0 ? _mm256_fmadd_ps(vd, sf, accd[0]) : _mm256_mul_ps(vd, sf);
- }
- else {
-
- deq1.prepare(i, 0); deq2.prepare(i, 1);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto vd = _mm256_set1_ps(deq1.d*q8.scale(iy, i));
- auto sumi = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)),
- _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1)));
- sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)),
- _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3))), sumi);
- sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)),
- _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi);
- sumi = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)),
- _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi);
- sumi = _mm256_sub_epi16(sumi, q8.load_bsums(iy, i));
-
- //auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[0], q8.load_quants(iy, i, 0)),
- // _mm256_maddubs_epi16(deq1.bits.values[1], q8.load_quants(iy, i, 1)));
- //auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq1.bits.values[2], q8.load_quants(iy, i, 2)),
- // _mm256_maddubs_epi16(deq1.bits.values[3], q8.load_quants(iy, i, 3)));
- //sumi1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[0], q8.load_quants(iy, i, 4)),
- // _mm256_maddubs_epi16(deq2.bits.values[1], q8.load_quants(iy, i, 5))), sumi1);
- //sumi2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq2.bits.values[2], q8.load_quants(iy, i, 6)),
- // _mm256_maddubs_epi16(deq2.bits.values[3], q8.load_quants(iy, i, 7))), sumi2);
- //auto sumi = _mm256_add_epi16(sumi2, _mm256_sub_epi16(sumi1, q8.load_bsums(iy, i)));
- auto sf = _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi));
- accd[iy] = i > 0 ? _mm256_fmadd_ps(vd, sf, accd[iy]) : _mm256_mul_ps(vd, sf);
- }
- }
-
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, hsum_float_8(accd[iy]));
- }
-
- }
-}
-
template <typename Dequantizer, int nrc_y>
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
@@ -2471,7 +2294,7 @@ struct DequantizerIQ1BN {
};
-template <int nrc_y, bool is_iq1_tn>
+template <int nrc_y>
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
@@ -2486,14 +2309,14 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
const block_iq1_bn * x;
const char * cx0 = (const char *)vx;
float scale;
+ ggml_half d16;
for (int ix = 0; ix < nrc_x; ++ix) {
const char * cx = cx0 + ix*bx;
- if constexpr (is_iq1_tn) {
- scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx);
- cx += sizeof(ggml_half);
- }
+ std::memcpy(&d16, cx, sizeof(d16));
+ scale = GGML_FP16_TO_FP32(d16);
+ cx += sizeof(d16);
x = (const block_iq1_bn *)cx;
if constexpr (nrc_y == 1) {
@@ -2561,17 +2384,13 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
auto vd = q8.scale(iy);
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
- if constexpr (is_iq1_tn) {
- info.store(ix, iy, scale*hsum_float_4(sumf));
- } else {
- info.store(ix, iy, hsum_float_4(sumf));
- }
+ info.store(ix, iy, scale*hsum_float_4(sumf));
}
}
}
-struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
+struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> {
DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
@@ -2671,7 +2490,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
auto vd = q8.scale(iy);
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
- info.store(ix, iy, hsum_float_4(sumf));
+ info.store(ix, iy, deq.d*hsum_float_4(sumf));
}
}
}
@@ -4075,30 +3894,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerQ2K>(mm);
break;
- case GGML_TYPE_IQ2_TN:
- assert (ne00 % QK_K == 0);
-#ifdef HAVE_FANCY_SIMD
- //MulMat::set_functions<DequantizerIQ2TN>(mm);
- mm.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<DequantizerIQ2TN>;
- //mm.funcs[0] = mul_mat_iq2tn_q8_K_AVX512<1>;
- mm.funcs[1] = mul_mat_iq2tn_q8_K_AVX512<2>;
- mm.funcs[2] = mul_mat_iq2tn_q8_K_AVX512<3>;
- mm.funcs[3] = mul_mat_iq2tn_q8_K_AVX512<4>;
- mm.funcs[4] = mul_mat_iq2tn_q8_K_AVX512<5>;
- mm.funcs[5] = mul_mat_iq2tn_q8_K_AVX512<6>;
- mm.funcs[6] = mul_mat_iq2tn_q8_K_AVX512<7>;
- mm.funcs[7] = mul_mat_iq2tn_q8_K_AVX512<8>;
-#else
- mm.funcs[0] = mul_mat_iq2tn_q8_K<1>;
- mm.funcs[1] = mul_mat_iq2tn_q8_K<2>;
- mm.funcs[2] = mul_mat_iq2tn_q8_K<3>;
- mm.funcs[3] = mul_mat_iq2tn_q8_K<4>;
- mm.funcs[4] = mul_mat_iq2tn_q8_K<5>;
- mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
- mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
- mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
-#endif
- break;
case GGML_TYPE_Q3_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerQ3K>(mm);
@@ -4173,26 +3968,14 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_IQ1_BN:
assert (ne00 % QK_IQ1BN == 0);
- mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>;
- mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>;
- mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>;
- mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>;
- mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>;
- mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>;
- mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>;
- mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>;
- expected_typeB = GGML_TYPE_Q8_K64;
- break;
- case GGML_TYPE_IQ1_TN:
- assert (ne00 % QK_IQ1BN == 0);
- mm.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>;
- mm.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>;
- mm.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>;
- mm.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>;
- mm.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>;
- mm.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>;
- mm.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>;
- mm.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>;
+ mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
+ mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
+ mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
+ mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
+ mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
+ mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
+ mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
+ mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
@@ -5410,156 +5193,6 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
};
-struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> {
- DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
- constexpr static int num_blocks() { return 16; }
- constexpr static bool should_scale_quants() { return true; }
-
- //template <typename Q8>
- //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
- // d = GGML_FP16_TO_FP32(x[i].d);
- //}
-
- inline void new_block(int) { }
-
- template <typename Q8>
- inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
- for (int iy = 0; iy < Q8::nrc_y; ++iy) {
- auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
- vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
-
- auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
- vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
-
- auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
- vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
-
- auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
- sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
- vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
- }
- }
- template <typename Q8>
- inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) {
- auto q8b_1 = q8.load_quants(0, i, 4*j+0);
- sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
- vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
-
- auto q8b_2 = q8.load_quants(0, i, 4*j+1);
- sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
- vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
-
- q8b_1 = q8.load_quants(0, i, 4*j+2);
- sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]),
- vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]);
-
- q8b_2 = q8.load_quants(0, i, 4*j+3);
- sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]),
- vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]);
- }
-
- IQK_ALWAYS_INLINE void prepare(int i, int j) {
- bits.prepare(x[i].qs+32*j);
- auto m1 = vdupq_n_s8(1);
- for (int k = 0; k < 4; ++k) {
- bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1);
- bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1);
- }
- }
-
- Q2bits bits;
-};
-
-template <int nrc_y>
-void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<nrc_y, block_q8_K> q8(info);
-
- DequantizerIQ2TN deq(vx, bx, nrc_y);
- float32x4_t acc[nrc_y];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- int32x4_t sumi[nrc_y];
- for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
-
- deq.new_block(i);
- deq.prepare(i, 0);
- deq.compute(q8, i, 0, sumi);
- deq.prepare(i, 1);
- deq.compute(q8, i, 1, sumi);
-
- if (i > 0) {
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
- }
- } else {
- for (int iy = 0; iy < nrc_y; ++iy) {
- acc[iy] = vmulq_f32(vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
- }
- }
- }
-
- for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, vaddvq_f32(acc[iy]));
- }
- }
-}
-void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- assert(n % QK_K == 0);
- const int nb = n / QK_K;
-
- Q8<1, block_q8_K> q8(info);
-
- DequantizerIQ2TN deq(vx, bx, 1);
-
- auto m1 = vdup_n_s16(-1);
- float32x4_t acc[2];
-
- for (int ix = 0; ix < nrc_x; ++ix) {
-
- deq.new_row(ix);
-
- for (int i = 0; i < nb; ++i) {
-
- int32x4_t sumi[2] = {};
- deq.new_block(i);
- auto bsums = q8.load_bsums(0, i);
- bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]);
- sumi[0] = vmlal_s16(sumi[0], vget_low_s16 (bsums.val[0]), m1);
- sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1);
- deq.bits.prepare(deq.x[i].qs);
- deq.compute1(q8, i, 0, sumi);
- deq.bits.prepare(deq.x[i].qs+32);
- deq.compute1(q8, i, 1, sumi);
-
- auto vd = vdupq_n_f32(deq.d*q8.scale(0, i));
- if (i > 0) {
- acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd);
- acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd);
- } else {
- acc[0] = vmulq_f32(vcvtq_f32_s32(sumi[0]), vd);
- acc[1] = vmulq_f32(vcvtq_f32_s32(sumi[1]), vd);
- }
-
- }
-
- acc[0] = vaddq_f32(acc[0], acc[1]);
- info.store(ix, 0, vaddvq_f32(acc[0]));
- }
-}
-
-
template <int nrc_y, typename Dequantizer>
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -6630,7 +6263,7 @@ struct DequantizerIQ1BN {
}
};
-template <int nrc_y, bool is_iq1_tn>
+template <int nrc_y>
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
@@ -6641,14 +6274,16 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
int8x16x4_t v1, v2;
float scale;
+ ggml_half d16;
+ char * c16 = (char *)&d16;
for (int ix = 0; ix < nrc_x; ++ix) {
const char * cx = ((const char *)vx + ix*bx);
- if constexpr (is_iq1_tn) {
- scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx);
- cx += sizeof(ggml_half);
- }
+ c16[0] = cx[0]; c16[1] = cx[1];
+ //std::memcpy(&d16, cx, sizeof(d16));
+ cx += sizeof(d16);
+ scale = GGML_FP16_TO_FP32(d16);
const block_iq1_bn * x = (const block_iq1_bn *)cx;
@@ -6704,11 +6339,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- if constexpr (is_iq1_tn) {
- info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
- } else {
- info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
- }
+ info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
@@ -6726,7 +6357,9 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int ix = 0; ix < nrc_x; ++ix) {
- const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ const float d = *dptr;
+ const block_iq2_bn * x = (const block_iq2_bn *)(dptr + 1);
if constexpr (nrc_y == 1) {
int8x16x4_t v1;
@@ -6789,7 +6422,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix, iy, -vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ info.store(ix, iy, -d*vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}
@@ -6859,17 +6492,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q2_K:
MulMat::set_functions<DequantizerQ2K>(m);
break;
- case GGML_TYPE_IQ2_TN:
- //MulMat::set_functions<DequantizerIQ2TN>(m);
- m.funcs[0] = mul_mat_iq2tn_K_q8_K_1;
- m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>;
- m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>;
- m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>;
- m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>;
- m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>;
- m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>;
- m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>;
- break;
case GGML_TYPE_Q3_K:
MulMat::set_functions<DequantizerQ3K>(m);
break;
@@ -6925,25 +6547,14 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerIQ3S>(m);
break;
case GGML_TYPE_IQ1_BN:
- m.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>;
- m.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>;
- m.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>;
- m.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>;
- m.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>;
- m.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>;
- m.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>;
- m.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>;
- expected_Btype = GGML_TYPE_Q8_K64;
- break;
- case GGML_TYPE_IQ1_TN:
- m.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>;
- m.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>;
- m.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>;
- m.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>;
- m.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>;
- m.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>;
- m.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>;
- m.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>;
+ m.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
+ m.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
+ m.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
+ m.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
+ m.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
+ m.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
+ m.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
+ m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN: