 ggml/src/ggml-cuda/common.cuh                                 |   14
 ggml/src/ggml-cuda/convert.cu                                 |   23
 ggml/src/ggml-cuda/iqk_mmvq.cu                                |  137
 ggml/src/ggml-cuda/iqk_mmvq.cuh                               |   15
 ggml/src/ggml-cuda/mmq.cu                                     |   12
 ggml/src/ggml-cuda/mmq.cuh                                    |  253
 ggml/src/ggml-cuda/mmvq.cu                                    |   12
 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu  |    5
 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu  |    5
 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu  |    5
 ggml/src/ggml-metal.metal                                     |   77
 ggml/src/ggml.c                                               |   25
 ggml/src/iqk/iqk_gemm_ktquants.cpp                            | 1043
 ggml/src/iqk/iqk_mul_mat.cpp                                  |   12
 ggml/src/iqk/iqk_quantize.cpp                                 |  161
 src/llama.cpp                                                 |    1
 16 files changed, 1668 insertions(+), 132 deletions(-)
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 8f3d2a26..a0cdab28 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -579,6 +579,20 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KT> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 01b7250e..b40079a3 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -340,6 +340,12 @@ inline __device__ int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
+int __device__ __forceinline__ trellis_next_int(uint32_t& val) {
+ constexpr uint32_t ka = 0xCBAC1FED;
+ val = ka*val;
+ return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126);
+}
+
float __device__ __forceinline__ trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
@@ -367,9 +373,9 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
- const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
+ const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f;
for (int j = 0; j < 8; ++j) {
- y[j] = dl * trellis_next(idx);
+ y[j] = dl * trellis_next_int(idx);
}
}
@@ -388,10 +394,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
- const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
+ const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 1.01f; //1.015f;
uint8_t mask = 1 << (ib/4);
for (int j = 0; j < 8; ++j) {
- y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
+ y[j] = dl * std::abs(trellis_next_int(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
}
}
@@ -401,9 +407,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const float * dptr = (const float *)((const char *)vx + row * row_size);
- float scale = dptr[0] * 31.75f * 1.01f;
- float row_av = dptr[1];
- const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+ float scale = dptr[0] * 1.00f;
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
const int64_t i = ii - (row*n_per_row)/QK_K;
constexpr int kNumGroups = 64;
@@ -423,8 +428,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
const float dl = scale * ls;
for (int j = 0; j < 4; ++j) {
- y[j+0] = dl * trellis_next(idx1) + row_av;
- y[j+4] = dl * trellis_next(idx2) + row_av;
+ y[j+0] = dl * trellis_next_int(idx1);
+ y[j+4] = dl * trellis_next_int(idx2);
}
}
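
Note: the convert.cu changes above swap the float trellis generator for an integer one (trellis_next_int) and drop the 31.75f factor from the dequantization scales. A minimal host-side C++ sketch of what the new device helper computes (not part of the diff; ggml_cuda_dp4a with a 0x01010101 multiplier is just a byte sum):

```cpp
#include <cstdint>

// Reference (serial) form: advance the 32-bit state by one multiplication
// with ka, mask each byte to 6 bits, and sum the four bytes with an offset
// of -126, giving an integer in [-126, 126].
static inline int trellis_next_int_ref(uint32_t &val) {
    constexpr uint32_t ka = 0xCBAC1FEDu;
    val *= ka;
    const uint32_t m = val & 0x3f3f3f3fu;
    return int( m        & 0xff) + int((m >>  8) & 0xff)
         + int((m >> 16) & 0xff) + int((m >> 24) & 0xff) - 126;
}
```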
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index e5a224b4..c19215de 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -433,6 +433,119 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}
+__device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+ float scale = *(const float *)vbq;
+ const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+ // iqs is 0...28
+ const int ib32 = iqs/4; // Why iqs/4 ?
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ //const int8_t * q8 = bq8_1[ib32].qs;
+ const int ls = (bq4->qs[ib32] & 0xff) >> 1;
+ const float dl = scale * (ls - 64);
+ const uint32_t idx0 = ((bq4->qs[ib32] & 1) << 15) + 4096;
+ auto ql = (const uint8_t *)(bq4->qs + 8);
+ auto qh = ql + 64;
+ ql += 8*ib32;
+ qh += 8*(ib32%4);
+ const int shift1 = 8 - 4*(ib32/4);
+ int sumi = 0;
+ for (int j = 0; j < 8; ++j) {
+ const uint32_t sh = bq4->qs[ib32] >> (8 + 3*j);
+ uint32_t val = ql[j] + ((qh[j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0;
+ int v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ //int s = val & km;
+ //sumi += q8[4*j+k] * ggml_cuda_dp4a(s, 0x01010101, -126);
+ v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ sumi = ggml_cuda_dp4a(v4, q8[j], sumi);
+ }
+ *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
+__device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+ float scale = *(const float *)vbq;
+ const block_iq2_kt * bq2 = (const block_iq2_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+ // iqs is 0...28
+ const int ib32 = iqs/4;
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const int ls = iq4k_values[(bq2->scales[ib32%4] >> 4*(ib32/4)) & 0xf];
+ const float dl = scale * ls * 1.05f;
+ auto ql = (const uint16_t *)bq2->ql;
+ int sumi = 0;
+ for (int j = 0; j < 4; ++j) {
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
+ v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
+ }
+ *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
+__device__ __forceinline__ void vec_dot_iq3_kt_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+ float scale = *(const float *)vbq;
+ const block_iq3_kt * bq3 = (const block_iq3_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+ // iqs is 0...28
+ const int ib32 = iqs/4;
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const int ls = (bq3->scales[ib32%4] >> 4*(ib32/4)) & 0xf;
+ const float dl = scale * ls * 1.015f;
+ auto ql = (const uint16_t *)bq3->ql;
+ uint32_t mask = 0x01010101 << ib32;
+ const uint32_t * qh = (const uint32_t *)bq3->qh;
+ int sumi = 0;
+ for (int j = 0; j < 4; ++j) {
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
+ v4 |= q << 8*k;
+ }
+ uint32_t signs = __vcmpne4(qh[2*j+0] & mask, 0);
+ v4 = __vsub4(v4 ^ signs, signs);
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
+ v4 = 0;
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
+ v4 |= q << 8*k;
+ }
+ signs = __vcmpne4(qh[2*j+1] & mask, 0);
+ v4 = __vsub4(v4 ^ signs, signs);
+ sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
+ }
+ *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
#define VDR_IQ4_KSS_Q8_1_MMVQ 4
#define VDR_IQ4_KSS_Q8_1_MMQ 4
@@ -1217,6 +1330,30 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
+void mul_mat_vec_iq4_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
+void mul_mat_vec_iq2_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
+void mul_mat_vec_iq3_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq3_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
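
Note: vec_dot_iq4_kt_q8_1 above rebuilds a 15-bit trellis index (plus the 4096/32768 offset) for each group of 8 weights from three packed fields. A plain C++ sketch of that index assembly, mirroring the code in the diff (the helper name is illustrative):

```cpp
#include <cstdint>

// Index for group j (0..7) of 32-weight block ib32 (0..7): 8 low bits from ql,
// 4 middle bits from a qh nibble, 3 top bits from the per-block word shb,
// plus 4096 and an optional 32768 offset bit.
static inline uint32_t iq4_kt_index(const uint32_t *shb, const uint8_t *ql,
                                    const uint8_t *qh, int ib32, int j) {
    const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
    const uint32_t sh     = shb[ib32] >> (8 + 3*j);
    const int      shift  = 8 - 4*(ib32/4);   // selects low or high qh nibble
    return ql[8*ib32 + j]
         + ((uint32_t(qh[8*(ib32 % 4) + j]) << shift) & 0xf00)
         + ((sh & 7) << 12)
         + offset;
}
```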
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 17bf5ad2..e7c6e1d2 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -100,3 +100,18 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq2_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq3_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq4_kt_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const char * ids_data,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index a13be11b..9103e3f1 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -100,6 +100,15 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_KS_R4:
mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream);
break;
+ case GGML_TYPE_IQ4_KT:
+ mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_KT:
+ mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ3_KT:
+ mul_mat_q_case<GGML_TYPE_IQ3_KT>(ctx, args, stream);
+ break;
case GGML_TYPE_IQ5_KS:
mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream);
break;
@@ -172,6 +181,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
mmq_supported = true;
break;
default:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 608de8f0..c6e6a365 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -93,6 +93,9 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
GGML_ABORT("fatal error");
@@ -202,6 +205,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16;
+ case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0;
default : return tile_x_sizes{0, 0, 0};
}
}
@@ -250,6 +256,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0;
default : return 0;
}
}
@@ -2790,6 +2799,226 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kt(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0;
+
+ int ib32 = kqsx/4;
+ int j = kqsx%4;
+ const auto shb = bxi->qs;
+ const auto ql = (const uint8_t *)(shb + 8);
+ const auto qh = ql + 64;
+ const uint32_t sh = shb[ib32] >> (8 + 6*j);
+ uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
+ uint32_t val1 = offset + ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12);
+ uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9);
+ int2 v = {0, 0};
+ for (int k = 0; k < 4; ++k) {
+ val1 *= ka;
+ val2 *= ka;
+ v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k;
+ v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const float * dptr = (const float *)(x + i*stride);
+ const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0;
+ const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64);
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_kt(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_kt * bxi = (const block_iq2_kt *)(x + i*stride + sizeof(float)) + kbx0;
+
+ int ib32 = kqsx/4;
+ int j = kqsx%4;
+ const auto ql = (const uint16_t *)bxi->ql;
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int2 v = {0, 0};
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+ }
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const float * dptr = (const float *)(x + i*stride);
+ const float d = dptr[0] * 1.05f;
+ const block_iq2_kt * bxi = (const block_iq2_kt *)(dptr + 1) + kbx0;
+ int ib32 = threadIdx.x % 8;
+ const int ls = iq4k_values[(bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf];
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_kt(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+ constexpr uint32_t ka = 0xCBAC1FED;
+ constexpr uint32_t km = 0x3f3f3f3f;
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_kt * bxi = (const block_iq3_kt *)(x + i*stride + sizeof(float)) + kbx0;
+
+ int ib32 = kqsx/4;
+ int j = kqsx%4;
+ const auto ql = (const uint16_t *)bxi->ql;
+ const auto qh = (const uint32_t *)bxi->qh;
+ uint32_t mask = 0x01010101 << ib32;
+ uint32_t val = ql[4*ib32+j] + 4096;
+ int2 v = {0, 0};
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.x |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k;
+ }
+ auto signs = __vcmpne4(qh[2*j+0] & mask, 0);
+ v.x = __vsub4(v.x ^ signs, signs);
+ for (int k = 0; k < 4; ++k) {
+ val *= ka;
+ v.y |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k;
+ }
+ signs = __vcmpne4(qh[2*j+1] & mask, 0);
+ v.y = __vsub4(v.y ^ signs, signs);
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const float * dptr = (const float *)(x + i*stride);
+ const float d = dptr[0] * 1.01f;
+ const block_iq3_kt * bxi = (const block_iq3_kt *)(dptr + 1) + kbx0;
+ int ib32 = threadIdx.x % 8;
+ const int ls = (bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -3383,6 +3612,27 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS_R4> {
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KT> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KT> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kt<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_KT> {
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_kt<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3843,6 +4093,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);
// -------------------------------------------------------------------------------------------------------------------------
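
Note: the load_tiles_*_kt kernels above decode the trellis weights to int8 at tile-load time and then reuse the generic q8_0 MMA/dp4a dot products, so only the per-32-weight scale stays type specific. A small C++ sketch of the iq2_kt scale lookup used in load_tiles_iq2_kt (iq4k_values is the existing 16-entry value table in the code base, assumed here):

```cpp
#include <cstdint>

extern const int8_t iq4k_values[16];   // existing non-uniform value table (assumed)

// Scale of 32-weight block ib32 (0..7): two 4-bit indices per scale byte,
// looked up in iq4k_values and multiplied by the row scale and the 1.05
// fudge factor that previously sat in the dequantization path.
static inline float iq2_kt_block_scale(const uint8_t *scales, float row_d, int ib32) {
    const int idx = (scales[ib32 % 4] >> 4*(ib32 / 4)) & 0xf;
    return row_d * 1.05f * iq4k_values[idx];
}
```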
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 73caabab..6412be30 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -527,6 +527,15 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ4_KSS:
mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
+ case GGML_TYPE_IQ2_KT:
+ mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
+ case GGML_TYPE_IQ3_KT:
+ mul_mat_vec_iq3_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
+ case GGML_TYPE_IQ4_KT:
+ mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+ break;
case GGML_TYPE_IQ2_KS:
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
@@ -687,6 +696,9 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
+ case GGML_TYPE_IQ2_KT:
+ case GGML_TYPE_IQ3_KT:
+ case GGML_TYPE_IQ4_KT:
return true;
default:
return false;
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu
new file mode 100644
index 00000000..2d48f077
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_KT);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu
new file mode 100644
index 00000000..978bc6ca
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu
new file mode 100644
index 00000000..4590f919
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index a05a890e..c3c4f0bb 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -6596,6 +6596,37 @@ void kernel_mul_mv_iq2_k_f32_impl(
}
}
+struct Trellis3 {
+ constexpr constant static uint32_t kmask = 0x3f3f3f3f;
+ constexpr constant static uint32_t ka = 89226354;
+ constexpr constant static uint32_t kb = 64248484;
+ constexpr constant static uint32_t ka1 = ka*ka;
+ constexpr constant static uint32_t kb1 = kb*ka+kb;
+ constexpr constant static uint32_t ka2 = ka1*ka;
+ constexpr constant static uint32_t kb2 = kb1*ka+kb;
+ constexpr constant static uint32_t ka3 = ka2*ka;
+ constexpr constant static uint32_t kb3 = kb2*ka+kb;
+ static inline char4 gen4(uint32_t val) {
+ thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask};
+ thread const int8_t * a8 = (thread const int8_t *)aux;
+ char4 result;
+ for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3];
+ return result;
+ }
+ template <typename T4>
+ static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
+ thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3};
+ uint32_t aux32[2];
+ thread const int8_t * a8 = (thread const int8_t *)aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32[0] = aux[i] & kmask;
+ aux32[1] = (ka3*aux[i] + kb3) & kmask;
+ v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3];
+ v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7];
+ }
+ }
+};
+
struct Trellis {
constexpr constant static uint32_t kmask1 = 0x8fff8fff;
constexpr constant static uint32_t kmask2 = 0x3b603b60;
@@ -6691,7 +6722,7 @@ void kernel_mul_mv_iq2_kt_f32_impl(
float drow[N_DST];
for (int row = 0; row < N_DST; ++row) {
device const float * dptr = (device const float *)(cx + row*row_size);
- drow[row] = dptr[0] * 31.75f * 1.05f;
+ drow[row] = dptr[0] * 1.05f;
}
device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float));
@@ -6706,10 +6737,10 @@ void kernel_mul_mv_iq2_kt_f32_impl(
const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf];
- Trellis::gen8(q2[2*it+0]+4096, v1, v2);
+ Trellis3::gen8(q2[2*it+0]+4096, v1, v2);
auto sum = v1*y4[0] + v2*y4[1];
- Trellis::gen8(q2[2*it+1]+4096, v1, v2);
+ Trellis3::gen8(q2[2*it+1]+4096, v1, v2);
sum += v1*y4[2] + v2*y4[3];
sum *= ls;
@@ -8542,19 +8573,18 @@ template <typename type4x4>
void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256
int ib32 = il/2;
- half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 31.75h * 1.05h;
+ half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 1.05h;
device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
- half4 v1, v2;
+ char4 v1, v2;
for (int i = 0; i < 2; ++i) {
- Trellis::gen8(q2[i]+4096, v1, v2);
- v1 *= scale; v2 *= scale;
+ Trellis3::gen8(q2[i]+4096, v1, v2);
if constexpr (is_same_v<type4x4, half4x4>) {
- reg[2*i+0] = v1;
- reg[2*i+1] = v2;
+ reg[2*i+0] = {scale*(half)v1[0], scale*(half)v1[1], scale*(half)v1[2], scale*(half)v1[3]};
+ reg[2*i+1] = {scale*(half)v2[0], scale*(half)v2[1], scale*(half)v2[2], scale*(half)v2[3]};
} else {
- reg[2*i+0] = {(float)v1[0], (float)v1[1], (float)v1[2], (float)v1[3]};
- reg[2*i+1] = {(float)v2[0], (float)v2[1], (float)v2[2], (float)v2[3]};
+ reg[2*i+0] = {scale*(float)v1[0], scale*(float)v1[1], scale*(float)v1[2], scale*(float)v1[3]};
+ reg[2*i+1] = {scale*(float)v2[0], scale*(float)v2[1], scale*(float)v2[2], scale*(float)v2[3]};
}
}
}
@@ -8586,20 +8616,20 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread
device const uint32_t * shb = x->qs;
device const uint8_t * ql = (device const uint8_t *)(shb + 8);
device const uint8_t * qh = ql + 64;
- float scale = d * (((shb[ib32] & 0xff) >> 1) - 64);
+ const int ls = (shb[ib32] & 0xff) >> 1;
+ const float scale = d * (ls - 64);
const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
- const int jj = ib32*8 + 4*(il%2);
- ql += jj;
- qh += jj%32;
+ ql += 8*ib32;
+ qh += 8*(ib32%4);
uint32_t sh = (shb[ib32] >> (8 + 12*(il%2))) << 12;
- const int shift = 8 - 4*(jj/32);
+ const int shift = 8 - 4*(ib32/4);
for (int i = 0; i < 4; ++i) {
uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset;
- auto v = (float4)Trellis::gen4(idx);
- reg[i] = v * scale;
+ auto c4 = Trellis3::gen4(idx);
+ reg[i] = {scale*c4[0], scale*c4[1], scale*c4[2], scale*c4[3]};
}
}
@@ -8931,18 +8961,17 @@ struct DequantizerKT4 {
using type4x4 = T4x4;
DequantizerKT4(device const char * cx, short il = 0) : il(il) {
device const float * dptr = (device const float *)cx;
- d[0] = dptr[0] * 31.75f * 1.01f;
- d[1] = dptr[1];
- x = (device const Block *)(dptr + 2);
+ d = dptr[0] * 1.01f;
+ x = (device const Block *)(dptr + 1);
}
inline void convert(thread T4x4& t) const {
float4x4 tmp;
- dequantize_iq4_kt(x, il, d[0], tmp);
+ dequantize_iq4_kt(x, il, d, tmp);
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
}
inline void convert(int64_t ind, thread T4x4& t) {
float4x4 tmp;
- dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp);
+ dequantize_iq4_kt(x + ind/nl, ind%nl, d, tmp);
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
}
inline void next() {
@@ -8951,7 +8980,7 @@ struct DequantizerKT4 {
}
device const Block * x;
short il;
- float d[2];
+ float d;
};
template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(half d, device const Block *, short, thread T4x4&), bool may_not_be_aligned = false>
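
Note: the Metal Trellis3 above keeps the original ka/kb affine generator but precomputes the constants for applying it two, three, and four times, so gen4()/gen8() can emit several values from one seed without a serial dependency. A host-side C++ check of that closed form (illustrative only, not part of the diff):

```cpp
#include <cstdint>
#include <cassert>

int main() {
    constexpr uint32_t ka = 89226354u, kb = 64248484u;
    constexpr uint32_t ka1 = ka*ka,  kb1 = kb*ka + kb;   // two applications
    constexpr uint32_t ka2 = ka1*ka, kb2 = kb1*ka + kb;  // three
    constexpr uint32_t ka3 = ka2*ka, kb3 = kb2*ka + kb;  // four
    uint32_t x = 0xdeadbeefu, s = x;
    const uint32_t step[4] = { ka*x + kb, ka1*x + kb1, ka2*x + kb2, ka3*x + kb3 };
    for (int i = 0; i < 4; ++i) {
        s = ka*s + kb;          // serial recurrence x -> ka*x + kb (mod 2^32)
        assert(s == step[i]);   // matches the precomputed closed form
    }
    return 0;
}
```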
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 69b1b46d..cc056f89 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1596,10 +1596,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq2_kt,
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref,
.vec_dot = vec_dot_iq2_kt_q8_k,
-#ifdef __ARM_NEON
- .vec_dot_type = GGML_TYPE_F16,
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
- .vec_dot_type = GGML_TYPE_F32,
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
.nrows = 1,
.row_meta_size = 4,
@@ -1613,11 +1613,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq3_kt,
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref,
.vec_dot = vec_dot_iq3_kt_q8_k,
-#ifdef __ARM_NEON
- .vec_dot_type = GGML_TYPE_F16,
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
- .vec_dot_type = GGML_TYPE_F32,
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
+//#ifdef __ARM_NEON
+// .vec_dot_type = GGML_TYPE_F16,
+//#else
+// .vec_dot_type = GGML_TYPE_F32,
+//#endif
.nrows = 1,
.row_meta_size = 4,
},
@@ -1630,13 +1635,13 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_kt,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
.vec_dot = vec_dot_iq4_kt_q8_k,
-#ifdef __ARM_NEON
- .vec_dot_type = GGML_TYPE_F16,
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
- .vec_dot_type = GGML_TYPE_F32,
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
.nrows = 1,
- .row_meta_size = 8,
+ .row_meta_size = 4,
},
[GGML_TYPE_IQ3_K] = {
.type_name = "iq3_k",
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index bc7bcf8b..2ddfbe86 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -97,6 +97,134 @@ struct Trellis2 {
}
};
+
+template <bool is_8 = false, bool is_abs = false>
+struct Trellis3 {
+ constexpr static uint32_t ka = 0xCBAC1FED;
+ constexpr static uint32_t ka1 = ka*ka;
+ constexpr static uint32_t ka2 = ka1*ka;
+ constexpr static uint32_t ka3 = ka2*ka;
+ constexpr static uint32_t ka4 = ka3*ka;
+ constexpr static uint32_t ka5 = ka4*ka;
+ constexpr static uint32_t ka6 = ka5*ka;
+ constexpr static uint32_t ka7 = ka6*ka;
+ const __m256i mka = is_8 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3);
+ const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
+
+ inline __m256i next8(uint32_t val1, uint32_t val2) const {
+ __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
+ return _mm256_mullo_epi32(mval, mka);
+ }
+ inline __m256i next8(uint32_t val) const {
+ __m256i mval = _mm256_set1_epi32(val);
+ return _mm256_mullo_epi32(mval, mka);
+ }
+ inline __m256 gen8(uint32_t val1, uint32_t val2) const {
+ auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f));
+#ifdef HAVE_FANCY_SIMD
+ auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8);
+#else
+ auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101));
+ auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+ if constexpr (is_abs) {
+ return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8));
+ } else {
+ return _mm256_cvtepi32_ps(i8);
+ }
+ }
+ inline __m256 gen8(uint32_t val) const {
+ auto v8 = _mm256_and_si256(next8(val), _mm256_set1_epi32(0x3f3f3f3f));
+#ifdef HAVE_FANCY_SIMD
+ auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8);
+#else
+ auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101));
+ auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+ if constexpr (is_abs) {
+ return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8));
+ } else {
+ return _mm256_cvtepi32_ps(i8);
+ }
+ }
+ inline __m256i next32(const uint32_t * val) const {
+ const __m256i offset = _mm256_set1_epi32(-126);
+ __m256i aux[4];
+ for (int i = 0; i < 4; ++i) {
+ auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f));
+#ifdef HAVE_FANCY_SIMD
+ aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8);
+#else
+ auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101));
+ aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+ }
+ aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+ // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+ if constexpr (is_abs) {
+ auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle);
+ return _mm256_sign_epi8(result, result);
+ } else {
+ return _mm256_permutevar8x32_epi32(aux[0], shuffle);
+ }
+ }
+ inline __m256i next32(const uint16_t * val, uint32_t v0) const {
+ const __m256i offset = _mm256_set1_epi32(-126);
+ __m256i aux[4];
+ for (int i = 0; i < 4; ++i) {
+ auto i8 = _mm256_and_si256(next8(v0 + val[i]), _mm256_set1_epi32(0x3f3f3f3f));
+#ifdef HAVE_FANCY_SIMD
+ aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8);
+#else
+ auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101));
+ aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+ }
+ aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+ // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+ if constexpr (is_abs) {
+ auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle);
+ return _mm256_sign_epi8(result, result);
+ } else {
+ return _mm256_permutevar8x32_epi32(aux[0], shuffle);
+ }
+ }
+ inline void next64(const uint32_t * val, __m256i * result) const {
+ const __m256i offset = _mm256_set1_epi32(-126);
+ auto vka3 = _mm256_set1_epi32(ka3);
+ __m256i aux[8];
+ for (int i = 0; i < 4; ++i) {
+ auto i8_1 = next8(val[2*i+0], val[2*i+1]);
+ auto i8_2 = _mm256_mullo_epi32(i8_1, vka3);
+ i8_1 = _mm256_and_si256(i8_1, _mm256_set1_epi32(0x3f3f3f3f));
+ i8_2 = _mm256_and_si256(i8_2, _mm256_set1_epi32(0x3f3f3f3f));
+#ifdef HAVE_FANCY_SIMD
+ aux[i+0] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8_1);
+ aux[i+4] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8_2);
+#else
+ auto dot1 = _mm256_maddubs_epi16(i8_1, _mm256_set1_epi32(0x01010101));
+ auto dot2 = _mm256_maddubs_epi16(i8_2, _mm256_set1_epi32(0x01010101));
+ aux[i+0] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot1, _mm256_set1_epi16(1)));
+ aux[i+4] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot2, _mm256_set1_epi16(1)));
+#endif
+ }
+ for (int k = 0; k < 2; ++k) {
+ aux[4*k+0] = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ aux[4*k+2] = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ aux[4*k+0] = _mm256_packs_epi16(aux[4*k+0], aux[4*k+2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+ // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+ result[k] = _mm256_permutevar8x32_epi32(aux[4*k+0], shuffle);
+ if constexpr (is_abs) {
+ result[k] = _mm256_sign_epi8(result[k], result[k]);
+ }
+ }
+ }
+};
+
void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
const int nb = n/QK_K;
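
Note: the AVX2 Trellis3 above drops the additive constant and generates up to eight consecutive states from one seed (or four from each of two seeds) by multiplying with precomputed powers of ka, since the i-th state of `val *= ka` is just ka^i * val (mod 2^32). A host-side sketch of that equivalence (illustrative, not part of the diff):

```cpp
#include <cstdint>
#include <cassert>

int main() {
    constexpr uint32_t ka = 0xCBAC1FEDu;
    uint32_t powers[8];                       // ka, ka^2, ..., ka^8, as in mka
    powers[0] = ka;
    for (int i = 1; i < 8; ++i) powers[i] = powers[i-1]*ka;

    uint32_t seed = 4096 + 1234, s = seed;
    for (int i = 0; i < 8; ++i) {
        s *= ka;                              // serial generator, one step per value
        assert(s == powers[i]*seed);          // what next8(seed) computes per lane
    }
    return 0;
}
```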
@@ -136,6 +264,60 @@ void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t
}
}
+void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+ const int nb = n/QK_K;
+
+ Trellis3 trellis;
+
+ auto shifts = _mm_set_epi32(0, 0, 4, 0);
+ auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const block_iq2_kt * x8[8];
+ float dkt[8];
+ float ls[8];
+ float ls_all[64];
+ uint32_t idx[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+ dkt[k] = dptr[0];
+ x8[k] = (const block_iq2_kt *)(dptr + 1);
+ }
+ auto vd = _mm256_mul_ps(_mm256_set1_ps(1.05f), _mm256_loadu_ps(dkt));
+
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ auto s8 = _mm_set1_epi32(*(const uint32_t *)x8[k][i].scales);
+ s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+ s8 = _mm_shuffle_epi8(values, s8);
+ auto s32 = _mm256_cvtepi8_epi32(s8);
+ _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32));
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+ auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls));
+ _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
+ idx[k] = ql[4*ib+j] + 4096;
+ }
+ __m256i packed[2];
+ trellis.next64(idx, packed);
+ _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]);
+ _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]);
+ }
+ }
+ y += 8; // = QK_K/32;
+ }
+ }
+}
+
template <int nrc_y>
void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
@@ -198,6 +380,243 @@ void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+template <int nrc_y>
+void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Trellis3<true> trellis;
+
+ auto shifts = _mm_set_epi32(0, 0, 4, 0);
+ auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+ constexpr int k_acc = nrc_y;
+
+ __m256 accd[k_acc];
+ const block_q8_2_x4 * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
+ }
+
+ __m256i xv[4], dot[4];
+ __m256 scales[2];
+
+ auto sum_4 = [&dot] () {
+ // dot[k] has 8 values from block k
+ // 0 1 0 1 0 1 0 1
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
+ // 2 3 2 3 2 3 2 3
+ dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
+ // 0 1 2 3 0 1 2 3
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
+ return _mm256_cvtepi32_ps(dot[0]);
+ };
+
+ auto compute_dot = [&dot, &xv] (const int8_t * y) {
+ for (int k = 0; k < 4; ++k) {
+ auto yv = _mm256_loadu_si256((const __m256i *)y + k);
+#ifdef HAVE_FANCY_SIMD
+ //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
+ dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+#else
+ auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+ dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
+#endif
+ }
+ };
+
+ //auto m126 = _mm256_set1_ps(-126.f);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ auto d = _mm256_set1_ps(dptr[0] * 1.05f);
+ const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+ s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+ s8 = _mm_shuffle_epi8(values, s8);
+ auto s32 = _mm256_cvtepi8_epi32(s8);
+ auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
+ auto scales_l = _mm256_castps256_ps128(all_scales);
+ auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+ scales[0] = _mm256_set_m128(scales_l, scales_l);
+ scales[1] = _mm256_set_m128(scales_h, scales_h);
+ for (int i128 = 0; i128 < 2; ++i128) {
+ //for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
+ for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_2_x4& yb = y[iy][2*i+i128];
+ auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
+ dy = _mm256_mul_ps(scales[i128], dy);
+ auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
+ //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
+ compute_dot(yb.qs);
+ accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
+ //accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
+
+void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+ const int nb = n/QK_K;
+
+ Trellis3<false, true> trellis;
+
+ auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const block_iq3_kt * x8[8];
+ float dkt[8];
+ float ls[8];
+ float ls_all[64];
+ uint32_t idx[8];
+ uint32_t sign_bits[16];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+ dkt[k] = dptr[0];
+ x8[k] = (const block_iq3_kt *)(dptr + 1);
+ }
+ auto vd = _mm256_mul_ps(_mm256_set1_ps(1.01f), _mm256_loadu_ps(dkt));
+
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ auto s8 = _mm_set1_epi32(*(const uint32_t *)x8[k][i].scales);
+ s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+ auto s32 = _mm256_cvtepi8_epi32(s8);
+ _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32));
+ }
+ auto mask = _mm256_set1_epi8(1);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+ auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls));
+ _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
+ idx[k] = ql[4*ib+j] + 4096;
+ auto qh = (const uint32_t *)x8[k][i].qh;
+ sign_bits[k+0] = qh[2*j+0];
+ sign_bits[k+8] = qh[2*j+1];
+ }
+ __m256i packed[2];
+ trellis.next64(idx, packed);
+ auto signs1 = _mm256_loadu_si256((const __m256i *)sign_bits+0);
+ auto signs2 = _mm256_loadu_si256((const __m256i *)sign_bits+1);
+ signs1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs1, mask), mask), _mm256_set1_epi8(1));
+ signs2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs2, mask), mask), _mm256_set1_epi8(1));
+ packed[0] = _mm256_sign_epi8(packed[0], signs1);
+ packed[1] = _mm256_sign_epi8(packed[1], signs2);
+ _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]);
+ _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]);
+ }
+ mask = _mm256_slli_epi16(mask, 1);
+ }
+ y += 8; // = QK_K/32;
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Trellis3<true, true> trellis;
+
+ auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+ constexpr int k_acc = nrc_y;
+
+ __m256 accd[k_acc];
+ const block_q8_2_x4 * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
+ }
+
+ __m256i xv[4], sv[4], dot[4];
+ __m256 scales[2];
+
+ auto sum_4 = [&dot] () {
+ // dot[k] has 8 values from block k
+ // 0 1 0 1 0 1 0 1
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
+ // 2 3 2 3 2 3 2 3
+ dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
+ // 0 1 2 3 0 1 2 3
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
+ return _mm256_cvtepi32_ps(dot[0]);
+ };
+
+ auto compute_dot = [&dot, &xv, &sv] (const int8_t * y) {
+ for (int k = 0; k < 4; ++k) {
+ auto yv = _mm256_loadu_si256((const __m256i *)y + k);
+#ifdef HAVE_FANCY_SIMD
+ //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
+ dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], _mm256_sign_epi8(yv, sv[k]));
+#else
+ auto p = _mm256_maddubs_epi16(xv[k], _mm256_sign_epi8(yv, sv[k]));
+ dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
+#endif
+ }
+ };
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ auto d = _mm256_set1_ps(dptr[0] * 1.01f);
+ const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ auto ql = (const uint16_t *)x[i].ql;
+ auto sign_bits = _mm256_loadu_si256((const __m256i *)x[i].qh);
+ auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+ s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+ auto s32 = _mm256_cvtepi8_epi32(s8);
+ auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
+ auto scales_l = _mm256_castps256_ps128(all_scales);
+ auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+ scales[0] = _mm256_set_m128(scales_l, scales_l);
+ scales[1] = _mm256_set_m128(scales_h, scales_h);
+ auto mask = _mm256_set1_epi8(1);
+ for (int i128 = 0; i128 < 2; ++i128) {
+ for (int k = 0; k < 4; ++k) {
+ xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
+ sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
+ mask = _mm256_slli_epi16(mask, 1);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_2_x4& yb = y[iy][2*i+i128];
+ auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
+ dy = _mm256_mul_ps(scales[i128], dy);
+ auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
+ compute_dot(yb.qs);
+ accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
+
inline __m256 abs_ps(__m256 vals) {
// Clear sign-bit of all the 32-bit floats in vals
__m256 sign_bit = _mm256_set1_ps(-0.0f);
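
Note: the mul_mat_*_q8_2_x4_T kernels above reconstruct the 16-bit block scales of the q8_2_x4 activations by widening them and shifting into the high half of a 32-bit float, which amounts to a bf16-to-f32 conversion. A scalar C++ sketch of one lane of that step (illustrative):

```cpp
#include <cstdint>
#include <cstring>

// One lane of the scale load: place the stored 16-bit value in the upper
// half of a 32-bit word and reinterpret it as a float.
static inline float q8_2_scale_to_f32(uint16_t d16) {
    const uint32_t u = uint32_t(d16) << 16;   // same as cvtepu16_epi32 + slli 16
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}
```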
@@ -315,19 +734,67 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+ const int nb = n/QK_K;
+ constexpr int kNumGroups = 64;
+
+ Trellis3 trellis;
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const block_iq4_kt * x8[8];
+ float dkt[8];
+ int32_t ls[8];
+ uint32_t idx0[8], idx[16];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+ dkt[k] = dptr[0];
+ x8[k] = (const block_iq4_kt *)(dptr + 1);
+ }
+ auto vd = _mm256_loadu_ps(dkt);
+
+ for (int i = 0; i < nb; ++i) {
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 8; ++k) {
+ ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64;
+ idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096;
+ }
+ auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls)));
+ _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+ int shift1 = 8 - 4*(ib/4);
+ for (int j = 0; j < 8; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8);
+ const uint8_t * qh = ql + kNumGroups;
+ const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j);
+ idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
+ }
+ _mm256_storeu_si256((__m256i *)y[ib].qs+j, trellis.next32(idx));
+ }
+ }
+ y += 8; // = QK_K/32;
+ }
+
+ }
+}
+
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
const int nb = n/QK_K;
constexpr int kNumGroups = 64;
- Trellis2 trellis;
+ Trellis3 trellis;
union { __m256 vec; float val[8]; } s_helper;
union { __m256i vec; uint32_t val[8]; } o_helper;
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
- auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
+ auto d = _mm256_set1_ps(dptr[0]);
auto dav = _mm256_set1_ps(dptr[1]);
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
@@ -349,8 +816,8 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
- auto x_val1 = _mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav);
- auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav);
+ auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav);
+ auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav);
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1);
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2);
@@ -365,12 +832,112 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t
}
template <int nrc_y>
+void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+ constexpr int kNumGroups = 64;
+
+ Trellis3 trellis;
+
+ union { __m256i vec; uint32_t val[8]; } o_helper;
+
+ constexpr int k_acc = nrc_y;
+
+ __m256 accd[k_acc];
+ const block_q8_2_x4 * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
+ }
+
+ uint32_t values[64];
+ __m256i xv[4], dot[4];
+ __m256 scales[2];
+
+ auto sum_4 = [&dot] () {
+ // dot[k] has 8 values from block k
+ // 0 1 0 1 0 1 0 1
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
+ // 2 3 2 3 2 3 2 3
+ dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
+ // 0 1 2 3 0 1 2 3
+ dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
+ return _mm256_cvtepi32_ps(dot[0]);
+ };
+
+ auto compute_dot = [&dot, &xv] (const int8_t * y) {
+ for (int k = 0; k < 4; ++k) {
+ auto yv = _mm256_loadu_si256((const __m256i *)y + k);
+#ifdef HAVE_FANCY_SIMD
+ //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
+ dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+#else
+ auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
+ dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
+#endif
+ }
+ };
+
+ //auto m126 = _mm256_set1_ps(-126.f);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ auto d = _mm256_set1_ps(dptr[0]);
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+ auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
+ const uint32_t * shb = x[i].qs;
+ const uint8_t * ql = (const uint8_t *)(shb + 8);
+ const uint8_t * qh = ql + kNumGroups;
+ auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
+ iscales = _mm256_sub_epi32(iscales, _mm256_set1_epi32(64));
+ auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(iscales));
+ auto scales_l = _mm256_castps256_ps128(all_scales);
+ auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+ scales[0] = _mm256_set_m128(scales_l, scales_l);
+ scales[1] = _mm256_set_m128(scales_h, scales_h);
+ o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
+ for (int ib = 0; ib < 4; ++ib) {
+ for (int j = 0; j < 4; ++j) {
+ const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
+ const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
+ values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
+ values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
+ values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
+ values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
+ }
+ }
+ for (int i128 = 0; i128 < 2; ++i128) {
+ //for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
+ for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_2_x4& yb = y[iy][2*i+i128];
+ auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
+ dy = _mm256_mul_ps(scales[i128], dy);
+ auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
+ //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
+ compute_dot(yb.qs);
+ accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
+ //accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
+
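For reference, a scalar sketch of the per-sub-block term the AVX2 mul_mat_iq4_kt_q8_2_x4_T kernel above accumulates (helper and argument names are illustrative, not from the diff): sum_4() reduces the four dpbusd/madd results so that lane k of each 128-bit half holds one of the two partial sums of sub-block k, and the broadcast scale vector d8 then applies the combined x and y scales per sub-block.

    // Minimal scalar sketch, assuming x8 holds the 32 int8 trellis values of one
    // sub-block (already including the -126 bias) and y8 the 32 int8 quants of
    // the matching Q8_2_X4 sub-block; d_x is the row scale, ls the 6-bit block
    // scale minus 64, d_y the activation block scale.
    static inline float kt_sub_block_term(float d_x, int ls, float d_y,
                                          const int8_t * x8, const int8_t * y8) {
        int dot = 0;
        for (int j = 0; j < 32; ++j) dot += int(x8[j]) * int(y8[j]);
        return d_x * float(ls) * d_y * float(dot);
    }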
+template <int nrc_y>
void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;
constexpr int kNumGroups = 64;
- Trellis2 trellis;
+ Trellis3 trellis;
union { __m256 vec; float val[8]; } s_helper;
union { __m256i vec; uint32_t val[8]; } o_helper;
@@ -389,7 +956,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
- auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
+ auto d = _mm256_set1_ps(dptr[0]);
auto dav = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
@@ -413,8 +980,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
- auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
- auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
+ auto x_val1 = _mm256_mul_ps(scale1, trellis.gen8(val1, val3));
+ auto x_val2 = _mm256_mul_ps(scale2, trellis.gen8(val2, val4));
if constexpr (nrc_y == 1) {
auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+ 0);
auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128);
@@ -446,11 +1013,37 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
- if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) {
+ if (ne00%QK_K != 0) return false;
+
+ func16 = nullptr;
+
+ if (typeA == GGML_TYPE_IQ4_KT) {
+ if (typeB == GGML_TYPE_Q8_2_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
+ return true;
+ }
return false;
}
- func16 = nullptr;
+ if (typeA == GGML_TYPE_IQ2_KT) {
+ if (typeB == GGML_TYPE_Q8_2_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_2_x4_T, kernels);
+ return true;
+ }
+ return false;
+ }
+
+ if (typeA == GGML_TYPE_IQ3_KT) {
+ if (typeB == GGML_TYPE_Q8_2_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_2_x4_T, kernels);
+ return true;
+ }
+ return false;
+ }
+
+ if (ggml_type(typeB) != GGML_TYPE_F32) {
+ return false;
+ }
switch (typeA) {
case GGML_TYPE_IQ2_KT:
@@ -470,11 +1063,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
}
-bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
+bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
switch (type) {
- case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
- case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
- case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
+ case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
+ case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
+ case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
default: return false;
}
return true;
@@ -992,21 +1585,423 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+struct Trellis3 {
+ constexpr static uint32_t ka = 0xCBAC1FED;
+ constexpr static uint32_t ka1 = ka*ka;
+ constexpr static uint32_t ka2 = ka1*ka;
+ constexpr static uint32_t ka3 = ka2*ka;
+ const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3};
+ const uint8x16_t shuffle = load_shuffle();
+
+ inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const {
+ uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)};
+ result.val[0] = vmulq_u32(mka, result.val[0]);
+ result.val[1] = vmulq_u32(mka, result.val[1]);
+ return result;
+ }
+ inline int8x16x2_t next32(const uint32_t * val) const {
+ int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
+ for (int i = 0; i < 2; ++i) {
+ auto i8 = next8(val[4*i+0], val[4*i+1]);
+ i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
+ auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+ i8 = next8(val[4*i+2], val[4*i+3]);
+ i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
+ auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+ result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+ }
+ return result;
+ }
+ inline int8x16x2_t next32(const uint16_t * val, uint32_t v0) const {
+ auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3);
+ int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
+ int8x16x2_t i8;
+ for (int i = 0; i < 2; ++i) {
+ i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+0]+v0));
+ i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]);
+ i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
+ auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+ i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+1]+v0));
+ i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]);
+ i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
+ auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
+ result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+ }
+ return result;
+ }
+ inline int8x16x4_t next64(const uint32_t * val) const {
+ auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3);
+ int8x16x4_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126)};
+ for (int i = 0; i < 2; ++i) {
+ auto i8_1 = next8(val[4*i+0], val[4*i+1]);
+ int8x16x2_t i8_2{vmlaq_u32(vkb3, vka3, i8_1.val[0]), vmlaq_u32(vkb3, vka3, i8_1.val[1])};
+ i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f));
+ i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f));
+ auto s1_1 = vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1]));
+ auto s1_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1]));
+ i8_1 = next8(val[4*i+2], val[4*i+3]);
+ i8_2.val[0] = vmlaq_u32(vkb3, vka3, i8_1.val[0]);
+ i8_2.val[1] = vmlaq_u32(vkb3, vka3, i8_1.val[1]);
+ i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f));
+ i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f));
+ i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f));
+ auto s2_1 = vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1]));
+ auto s2_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1]));
+ result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1));
+ result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2));
+ }
+ return result;
+ }
+ static uint8x16_t load_shuffle() {
+ static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
+ return vld1q_u8(k_shuffle);
+ }
+};
+
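In words, what Trellis3 produces per 32-bit seed (a minimal scalar sketch, not the NEON path itself; the helper name is made up): each step multiplies the seed by the next power of ka, keeps the low 6 bits of every byte, and sums the four bytes minus 126, which yields an int8 in [-126, 126] and is why the result vectors are initialised with vdupq_n_s8(-126).

    // Scalar sketch of Trellis3 output value j for a given seed (the packed
    // index plus its 4096/32768 offset); name is illustrative.
    static inline int8_t trellis3_value(uint32_t seed, int j) {
        constexpr uint32_t ka = 0xCBAC1FED;
        uint32_t state = seed;
        for (int k = 0; k <= j; ++k) state *= ka;      // lane j of mka is ka^(j+1)
        uint32_t m = state & 0x3f3f3f3f;               // keep 6 bits per byte
        int sum = (m & 0xff) + ((m >> 8) & 0xff) + ((m >> 16) & 0xff) + ((m >> 24) & 0xff);
        return int8_t(sum - 126);                      // range [-126, 126]
    }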
+void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+ const int nb = n/QK_K;
+ constexpr int kNumGroups = 64;
+
+ Trellis3 trellis;
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const block_iq4_kt * x8[8];
+ float dkt[8];
+ int32_t ls[8];
+ uint32_t idx0[8], idx[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+ dkt[k] = dptr[0];
+ x8[k] = (const block_iq4_kt *)(dptr + 1);
+ }
+ auto vd = vld1q_f32_x2(dkt);
+
+ for (int i = 0; i < nb; ++i) {
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 8; ++k) {
+ ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64;
+ idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096;
+ }
+ auto scales1 = vmulq_f32(vd.val[0], vcvtq_f32_s32(vld1q_s32(ls+0)));
+ auto scales2 = vmulq_f32(vd.val[1], vcvtq_f32_s32(vld1q_s32(ls+4)));
+ vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
+ vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
+ int shift1 = 8 - 4*(ib/4);
+ for (int j = 0; j < 8; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8);
+ const uint8_t * qh = ql + kNumGroups;
+ const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j);
+ idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
+ }
+ vst1q_s8_x2(y[ib].qs+32*j, trellis.next32(idx));
+ }
+ }
+ y += 8; // = QK_K/32;
+ }
+ }
+}
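The index math above (and the equivalent ql/qh/sh arithmetic in the mul_mat kernels) assembles one 15-bit trellis seed per group of four values. A scalar sketch, with illustrative names, for sub-block ib = 0..7 and group j = 0..7 of a super-block:

    static inline uint32_t iq4_kt_seed(const uint32_t * shb, const uint8_t * ql,
                                       const uint8_t * qh, int ib, int j) {
        uint32_t lo   = ql[8*ib + j];                                 // bits 0..7
        uint32_t mid  = ib < 4 ? (qh[8*ib + j] & 0xf)                 // bits 8..11: low or
                               : (qh[8*(ib-4) + j] >> 4);             // high nibble of qh
        uint32_t hi   = (shb[ib] >> (8 + 3*j)) & 7;                   // bits 12..14
        uint32_t base = 4096 + ((shb[ib] & 1) << 15);                 // offset codebook flag
        return lo + (mid << 8) + (hi << 12) + base;
    }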
+
+template <int nrc_y>
+void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+ constexpr int kNumGroups = 64;
+
+ Trellis3 trellis;
+
+ union { uint32x4x2_t vec; uint32_t val[8]; } o_helper;
+
+ constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+ float32x4_t accd[k_acc];
+
+ const block_q8_0_x4 * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
+ }
+
+ uint32_t values[16];
+ int8x16x2_t xv[8];
+ int32x4x4_t dot;
+
+ auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
+ for (int k = 0; k < 4; ++k) {
+ auto yv = vld1q_s8_x2(y + 32*k);
+ dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
+ }
+ dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
+ dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
+ return vpaddq_s32(dot.val[0], dot.val[2]);
+ };
+
+ int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}};
+
+ float32x4x2_t scales;
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ auto d = vdupq_n_f32(dptr[0]);
+ const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
+
+ for (int i = 0; i < nb; ++i) {
+ auto vshb = vld1q_u32_x2(x[i].qs);
+ const uint32_t * shb = x[i].qs;
+ const uint8_t * ql = (const uint8_t *)(shb + 8);
+ const uint8_t * qh = ql + kNumGroups;
+ auto iscales1 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff)), 1));
+ auto iscales2 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff)), 1));
+ iscales1 = vaddq_s32(iscales1, vdupq_n_s32(-64));
+ iscales2 = vaddq_s32(iscales2, vdupq_n_s32(-64));
+ scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(iscales1));
+ scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(iscales2));
+ o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
+ o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto vql1 = vmovl_u8(vld1_u8(ql+8*ib));
+ auto vql2 = vmovl_u8(vld1_u8(ql+8*ib+32));
+ auto vqh = vmovl_u8(vld1_u8(qh+8*ib));
+ vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 8)));
+ vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 4)));
+ auto sh1_u32 = vdupq_n_u32(shb[ib+0]);
+ auto sh2_u32 = vdupq_n_u32(shb[ib+4]);
+ auto sh1 = vcombine_u16(vmovn_u32(vshlq_u32(sh1_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh1_u32, shifts.val[1])));
+ auto sh2 = vcombine_u16(vmovn_u32(vshlq_u32(sh2_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh2_u32, shifts.val[1])));
+ vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0x7000), sh1));
+ vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0x7000), sh2));
+ auto oh1 = vdupq_n_u32(o_helper.val[ib+0]);
+ auto oh2 = vdupq_n_u32(o_helper.val[ib+4]);
+ vst1q_u32(values +0, vaddq_u32(vmovl_u16(vget_low_u16 (vql1)), oh1));
+ vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1));
+ vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2));
+ vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2));
+ xv[ib+0] = trellis.next32(values+0);
+ xv[ib+4] = trellis.next32(values+8);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_0_x4& ybl = y[iy][2*i+0];
+ const block_q8_0_x4& ybh = y[iy][2*i+1];
+ auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+ auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+ auto sumil = compute_dot(ybl.qs, xv+0);
+ auto sumih = compute_dot(ybh.qs, xv+4);
+ if constexpr (nrc_y == 1) {
+ accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
+ accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
+ } else {
+ accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
+ accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
+ }
+ }
+ }
+
+ if constexpr (nrc_y == 1) {
+ info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(accd[iy]));
+ }
+ }
+ }
+}
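The shifts vector in the kernel above is what moves each 3-bit index field into place: field j sits at bits 8+3*j of the 32-bit shb word and has to land at bits 12..14 of the 16-bit index, so the required shift is 4-3*j, and vshlq_u32 with a negative count performs the corresponding right shift. A scalar sketch (helper name illustrative):

    static inline uint16_t iq4_kt_3bit_field(uint32_t shb_word, int j) {
        int shift = 4 - 3*j;                                   // +4, +1, -2, ..., -17 for j = 0..7
        uint32_t v = shift >= 0 ? shb_word << shift : shb_word >> -shift;
        return uint16_t(v) & 0x7000;                           // field now at bits 12..14
    }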
+
+void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+ const int nb = n/QK_K;
+
+ Trellis3 trellis;
+
+ auto values = vld1q_s8(iq4k_values);
+
+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+ const block_iq2_kt * x8[8];
+ float dkt[8];
+ float ls[8], ls_all[64];
+ uint32_t idx[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) {
+ const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+ dkt[k] = dptr[0] * 1.05f;
+ x8[k] = (const block_iq2_kt *)(dptr + 1);
+ }
+ auto vd = vld1q_f32_x2(dkt);
+
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ auto u32 = *(const uint32_t *)x8[k][i].scales;
+ auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+ s8_u32 = vand_u32(s8_u32, vdup_n_u32(0x0f0f0f0f));
+ auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32));
+ auto s16 = vmovl_s8(s8);
+ vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
+ vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+ auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
+ auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
+ vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
+ vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 8; ++k) {
+ const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
+ idx[k] = ql[4*ib+j] + 4096;
+ }
+ vst1q_s8_x4(y[ib].qs+64*j, trellis.next64(idx));
+ }
+ }
+ y += 8; // = QK_K/32;
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Trellis3 trellis;
+
+ auto values = vld1q_s8(iq4k_values);
+
+ constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+ float32x4_t accd[k_acc];
+
+ const block_q8_0_x4 * y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
+ }
+
+ int8x16x2_t xv[8];
+ int32x4x4_t dot;
+
+ auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
+ for (int k = 0; k < 4; ++k) {
+ auto yv = vld1q_s8_x2(y + 32*k);
+ dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
+ }
+ dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
+ dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
+ return vpaddq_s32(dot.val[0], dot.val[2]);
+ };
+
+ float32x4x2_t scales;
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ const float * dptr = (const float *)((const char*)vx + ix*bx);
+ auto d = vdupq_n_f32(dptr[0]*1.05f);
+ const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+ for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
+
+ for (int i = 0; i < nb; ++i) {
+ auto u32 = *(const uint32_t *)x[i].scales;
+ auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+ s8_u32 = vand_u32(s8_u32, vdup_n_u32(0x0f0f0f0f));
+ auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32));
+ auto s16 = vmovl_s8(s8);
+ scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
+ scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+ const uint16_t * ql = (const uint16_t *)x[i].ql;
+ if constexpr (nrc_y == 1) {
+ const block_q8_0_x4& ybl = y[0][2*i+0];
+ const block_q8_0_x4& ybh = y[0][2*i+1];
+ auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+ auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+ int32x4x4_t suml = {};
+ int32x4x4_t sumh = {};
+ for (int ib = 0; ib < 4; ++ib) {
+ auto xl = trellis.next32(ql + 4*ib + 0, 4096);
+ auto xh = trellis.next32(ql + 4*ib + 16, 4096);
+ auto yl = vld1q_s8_x2(ybl.qs + 32*ib);
+ auto yh = vld1q_s8_x2(ybh.qs + 32*ib);
+ suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+ sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+ }
+ auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
+ auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
+ auto sl = vpaddq_s32(sl1, sl2);
+ auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
+ auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
+ auto sh = vpaddq_s32(sh1, sh2);
+ accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
+ accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
+ } else {
+ for (int k = 0; k < 8; ++k) xv[k] = trellis.next32(ql + 4*k, 4096);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const block_q8_0_x4& ybl = y[iy][2*i+0];
+ const block_q8_0_x4& ybh = y[iy][2*i+1];
+ auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+ auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+ auto sumil = compute_dot(ybl.qs, xv+0);
+ auto sumih = compute_dot(ybh.qs, xv+4);
+ if constexpr (nrc_y == 1) {
+ accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
+ accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
+ } else {
+ accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
+ accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
+ }
+ }
+ }
+ }
+
+ if constexpr (nrc_y == 1) {
+ info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(accd[iy]));
+ }
+ }
+ }
+}
+
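For the iq2_kt routines above, the vqtbl1_s8 lookup decodes the packed 4-bit block scales: sub-blocks 0..3 take the low nibbles and sub-blocks 4..7 the high nibbles of the four scale bytes, each mapped through the iq4k_values table and multiplied by the row scale times 1.05f. A scalar sketch (function name illustrative):

    static inline void iq2_kt_block_scales(const uint8_t * packed /* x[i].scales */,
                                           float d_row, float * out /* 8 scales */) {
        for (int ib = 0; ib < 8; ++ib) {
            uint8_t nib = ib < 4 ? (packed[ib] & 0xf) : (packed[ib-4] >> 4);
            out[ib] = d_row * 1.05f * iq4k_values[nib];   // iq4k_values: the int8 table loaded above
        }
    }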
}
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
- //if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
- // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
- // func16 = nullptr;
- // return true;
- //}
- if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) {
+ if (ne00%QK_K != 0) return false;
+
+ if (ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
+ if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_0_x4_T, kernels);
+ func16 = nullptr;
+ return true;
+ }
return false;
}
- func16 = nullptr;
+ if (ggml_type(typeA) == GGML_TYPE_IQ2_KT) {
+ if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
+ IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_0_x4_T, kernels);
+ func16 = nullptr;
+ return true;
+ }
+ return false;
+ }
+
+ if (ggml_type(typeB) != GGML_TYPE_F16) {
+ return false;
+ }
switch (typeA) {
case GGML_TYPE_IQ2_KT:
@@ -1022,14 +2017,16 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
return false;
}
+ func16 = nullptr;
+
return true;
}
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
switch (type) {
- case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
+ case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
- case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
+ case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
default: return false;
}
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 6925e6a6..f718e43e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -236,9 +236,6 @@ struct MulMat {
static inline ggml_type is_dequant_better(ggml_type type, int nrc_y) {
#ifdef __AVX2__
switch (type) {
- case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
- case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
- case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_S : return nrc_y >= 16 ? GGML_TYPE_Q8_K_R8 : type;
@@ -267,13 +264,16 @@ struct MulMat {
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+ case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+ case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+ case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#else
switch (type) {
- case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
+ case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
- case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
+ case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#endif
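Condensed, the effect of the switch changes above (helper name illustrative, not in the source): on the AVX2 path the trellis types are repacked to Q8_0_R8 once there are at least 32 right-hand-side columns, so the repack cost is amortized over the batch; on the NEON path IQ3_KT still falls back to F16 for now.

    static ggml_type kt_repack_type_avx2(ggml_type type, int nrc_y) {
        switch (type) {
            case GGML_TYPE_IQ2_KT:
            case GGML_TYPE_IQ3_KT:
            case GGML_TYPE_IQ4_KT:
                return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
            default: return type;
        }
    }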
@@ -815,7 +815,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ2_KT:
case GGML_TYPE_IQ3_KT:
case GGML_TYPE_IQ4_KT:
- return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false;
+ return iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index abd4be61..0384e49a 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -7397,7 +7397,7 @@ void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) {
}
namespace {
-template <int block_size, int group_size, int num_bits, bool is_abs = false>
+template <int block_size, int group_size, int num_bits, bool is_abs = false, bool is_int = false>
class QuantizerIQKT {
static_assert(group_size == 8 || group_size == 4);
static_assert(block_size >= 8 && block_size%8 == 0);
@@ -7408,7 +7408,7 @@ public:
constexpr static int kNg = kBlockSize/kGroupSize;
constexpr static int kNblock = kSuperBlockSize/kBlockSize;
constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8
- constexpr static float kScale = 31.75f;
+ constexpr static float kScale = is_int ? 1.f : 31.75f;
constexpr static bool kVerbose = false;
QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096);
@@ -7419,17 +7419,32 @@ public:
inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const;
static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) {
- constexpr uint32_t ka = 89226354;
- constexpr uint32_t kb = 64248484;
- constexpr uint32_t kmask = 0x8fff8fff;
- constexpr uint32_t km32 = 0x3b603b60;
uint32_t x = i + offset;
- for (int k = 0; k < kGroupSize; ++k) {
- x = ka*x + kb;
- uint32_t s = (x & kmask) ^ km32;
- float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16);
- if constexpr (is_abs) result[k] = scale*std::abs(val);
- else result[k] = scale*val;
+ if constexpr (is_int) {
+ constexpr uint32_t ka = 0xCBAC1FED;
+ uint32_t s;
+ auto i8 = (const int8_t *)&s;
+ for (int k = 0; k < kGroupSize; ++k) {
+ x = ka*x;
+ s = x & 0x3f3f3f3f;
+ if constexpr (is_abs) {
+ result[k] = scale*std::abs(i8[0] + i8[1] + i8[2] + i8[3] - 126.f);
+ } else {
+ result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f);
+ }
+ }
+ } else {
+ constexpr uint32_t ka = 89226354;
+ constexpr uint32_t kb = 64248484;
+ constexpr uint32_t kmask = 0x8fff8fff;
+ constexpr uint32_t km32 = 0x3b603b60;
+ for (int k = 0; k < kGroupSize; ++k) {
+ x = ka*x + kb;
+ uint32_t s = (x & kmask) ^ km32;
+ float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16);
+ if constexpr (is_abs) result[k] = scale*std::abs(val);
+ else result[k] = scale*val;
+ }
}
}
@@ -7478,14 +7493,15 @@ private:
float m_mid[4*kGroupSize];
};
-template <int block_size, int group_size, int num_bits, bool is_abs>
-QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) {
+template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int>
+QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) {
m_values.resize(kNumVal*kGroupSize);
float * data = m_values.data();
for (int i = 0; i < kNumVal; ++i) {
set_values(i, data, kScale, offset);
data += kGroupSize;
}
+ if (num_clusters == 0) return;
// Make 128 clusters.
// Note: we get a slightly better result by using 64 clusters
// at the expense of almost doubling the quantization time.
@@ -7494,8 +7510,8 @@ QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_c
m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values);
}
-template <int block_size, int group_size, int num_bits, bool is_abs>
-std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_scale(
+template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int>
+std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::find_best_scale(
const float * xb, const float * weight, const int * best_idx) const {
float sumqx = 0, sumq2 = 0;
#ifdef __AVX2__
@@ -7527,8 +7543,8 @@ std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>:
return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f);
}
-template <int block_size, int group_size, int num_bits, bool is_abs>
-float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse_scale(
+template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int>
+float QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::find_best_inverse_scale(
const float * xb, const float * weight, const int * best_idx) const {
float sumqx = 0, sumx2 = 0;
#ifdef __AVX2__
@@ -7560,8 +7576,8 @@ float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse
return sumx2 > 0 ? sumqx/sumx2 : 0.f;
}
-template <int block_size, int group_size, int num_bits, bool is_abs>
-void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const {
+template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int>
+void QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const {
if (!d) {
std::memset(best_idx, 0, kNg*sizeof(int));
return;
@@ -7739,8 +7755,8 @@ void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(fl
#endif
}
-template <int block_size, int group_size, int num_bits, bool is_abs>
-std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::finalize_clusters(int num_neighbours,
+template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int>
+std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::finalize_clusters(int num_neighbours,
const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) {
int ncluster = clusters.size()/kGroupSize;
std::vector<std::vector<int>> p_in_cluster(ncluster);
@@ -7826,8 +7842,8 @@ std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is
return p_in_cluster;
}
-template <int block_size, int group_size, int num_bits, bool is_abs>
-std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) {
+template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int>
+std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) {
constexpr int ndim = kGroupSize;
GGML_ASSERT(points.size() % ndim == 0);
int npoint = points.size() / ndim;
@@ -7995,7 +8011,7 @@ std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::clus
// ========================================== iq2_kt ====================================================
-using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16>;
+using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16, false, true>;
const QuantizerIQ2KT& iq2kt_quantizer() {
static std::mutex mutex;
@@ -8006,7 +8022,7 @@ const QuantizerIQ2KT& iq2kt_quantizer() {
}
void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights,
- float * qtmp) {
+ int * all_idx) {
constexpr float kSigmaScale = 2.0f;
using Q = QuantizerIQ2KT;
@@ -8025,6 +8041,11 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
+ float amax_row = 0;
+ for (int j = 0; j < n_per_row; ++j) {
+ amax_row = std::max(amax_row, std::abs(x[j]));
+ }
+
float amax_scale = 0, max_scale = 0;
for (int ibl = 0; ibl < nblock; ++ibl) {
@@ -8042,9 +8063,10 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
float ax = std::abs(xb[j]);
amax = std::max(amax, ax);
}
- quantizer.find_best_match( amax/96.f, xb, weight, best_idx);
+ float scale_0 = std::max(90.f, 124.f*amax/amax_row);
+ quantizer.find_best_match( amax/scale_0, xb, weight, best_idx);
auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
- quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg);
+ quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx + Q::kNg);
auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg);
auto idx = best_idx;
@@ -8052,12 +8074,7 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
else {
scales[ib] = dm; idx += Q::kNg;
}
- auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
- for (int ig = 0; ig < Q::kNg; ++ig) {
- auto q = quantizer.values() + idx[ig]*Q::kGroupSize;
- for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j];
- qt += Q::kGroupSize;
- }
+ for (int ig = 0; ig < Q::kNg; ++ig) all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig] = idx[ig];
float abs_scale = std::abs(scales[ib]);
if (abs_scale > amax_scale) {
@@ -8080,20 +8097,22 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
float sumqx = 0, sumq2 = 0;
for (int ibl = 0; ibl < nblock; ++ibl) {
const float * xb = x + ibl*Q::kSuperBlockSize;
- const float * qb = qtmp + ibl*Q::kSuperBlockSize;
const float * wb = all_weights + ibl*Q::kSuperBlockSize;
auto scales = all_scales + ibl*Q::kNblock;
for (int ib = 0; ib < Q::kNblock; ++ib) {
int ls = best_index_iq4nl(iq4k_values, id*scales[ib]);
float dl = iq4k_values[ls];
- for (int j = 0; j < Q::kBlockSize; ++j) {
- float q = dl*qb[j];
- sumqx += wb[j]*xb[j]*q;
- sumq2 += wb[j]*q*q;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ auto qb = quantizer.values() + Q::kGroupSize*all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig];
+ for (int j = 0; j < Q::kGroupSize; ++j) {
+ int jj = ig*Q::kGroupSize + j;
+ float q = dl*qb[j];
+ sumqx += wb[jj]*xb[jj]*q;
+ sumq2 += wb[jj]*q*q;
+ }
}
xb += Q::kBlockSize;
wb += Q::kBlockSize;
- qb += Q::kBlockSize;
}
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
@@ -8129,6 +8148,26 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f
float dl = d*ls;
quantizer.find_best_match(dl, xb, weight, best_idx);
+ auto prev_idx = all_idx + (ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize;
+
+ float mse1 = 0, mse2 = 0;
+ for (int ig = 0; ig < Q::kNg; ++ig) {
+ auto q1 = quantizer.values() + Q::kGroupSize*prev_idx[ig];
+ auto q2 = quantizer.values() + Q::kGroupSize*best_idx[ig];
+ for (int j = 0; j < Q::kGroupSize; ++j) {
+ int jj = ig*Q::kGroupSize + j;
+ float diff1 = xb[jj] - dl*q1[j];
+ float diff2 = xb[jj] - dl*q2[j];
+ mse1 += weight[jj]*diff1*diff1;
+ mse2 += weight[jj]*diff2*diff2;
+ }
+ }
+ if (mse1 < mse2) {
+ for (int ig = 0; ig < Q::kNg; ++ig) best_idx[ig] = prev_idx[ig];
+ } else {
+ for (int ig = 0; ig < Q::kNg; ++ig) prev_idx[ig] = best_idx[ig];
+ }
+
for (int j = 0; j < Q::kNg; ++j) {
qs[j] = best_idx[j];
auto xl = xb + Q::kGroupSize*j;
@@ -8196,10 +8235,10 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row);
std::vector<float> scales(n_per_row/QuantizerIQ2KT::kBlockSize);
std::vector<float> weights(n_per_row);
- std::vector<float> xtmp(n_per_row);
+ std::vector<int> idx(n_per_row/QuantizerIQ2KT::kGroupSize);
char * qrow = (char *)dst;
for (int64_t row = 0; row < nrows; ++row) {
- quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data());
+ quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), idx.data());
src += n_per_row;
qrow += row_size;
}
@@ -8209,7 +8248,7 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) {
assert(k % QuantizerIQ2KT::kSuperBlockSize == 0);
#ifdef __AVX2__
- if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return;
+ //if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return;
#endif
const int nb = k / QuantizerIQ2KT::kSuperBlockSize;
const float * dptr = (const float *)x;
@@ -8254,7 +8293,7 @@ void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx
namespace {
-using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>;
+using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true, true>;
const QuantizerIQ3KT& iq3kt_quantizer() {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
@@ -8465,7 +8504,7 @@ size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
#ifdef __AVX2__
- if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return;
+ //if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return;
#endif
using Q = QuantizerIQ3KT;
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
@@ -8521,7 +8560,7 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx
namespace{
-using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>;
+using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15, false, true>;
const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) {
static std::mutex mutex;
@@ -8536,6 +8575,14 @@ const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) {
return *quantizer1;
}
+const QuantizerIQ4KT& iq4kt_dequantizer() {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unique_ptr<QuantizerIQ4KT> dequantizer;
+ if (!dequantizer) dequantizer = std::make_unique<QuantizerIQ4KT>(0, 0, 4096);
+ return *dequantizer;
+}
+
void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) {
constexpr float kSigmaScale = 2.0f;
@@ -8546,7 +8593,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
float * dptr = (float *)vy;
- block_iq4_kt * y = (block_iq4_kt *)(dptr + 2);
+ block_iq4_kt * y = (block_iq4_kt *)(dptr + 1);
auto& quantizer1 = iq4kt_quantizer();
auto& quantizer2 = iq4kt_quantizer(true);
@@ -8555,13 +8602,10 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
- float amax_row = 0, row_av = 0;
+ float amax_row = 0;
for (int j = 0; j < n_per_row; ++j) {
- row_av += x[j];
amax_row = std::max(amax_row, std::abs(x[j]));
}
- row_av /= n_per_row;
- dptr[1] = row_av;
if (!amax_row) {
dptr[0] = 0.f;
std::memset(y, 0, nblock*sizeof(block_iq4_kt));
@@ -8584,7 +8628,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
float amax = 0;
for (int j = 0; j < Q::kBlockSize; ++j) {
- xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
+ xaux[j] = xbl[ib*Q::kBlockSize+j];
float ax = std::abs(xaux[j]);
amax = std::max(amax, ax);
}
@@ -8593,7 +8637,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
continue;
}
float best = 0;
- float scale_0 = std::max(92.f, 127.f*amax/amax_row);
+ float scale_0 = std::max(90.f, 124.f*amax/amax_row);
for (int itry = -kNtry; itry <= kNtry; ++itry) {
quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx);
auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx);
@@ -8664,7 +8708,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
for (int ib = 0; ib < Q::kNblock; ++ib) {
auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1;
const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
- for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
+ for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j];
int ls = nearest_int(id*scales[ib]);
ls = std::min(ls, 63);
*(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1);
@@ -8724,7 +8768,7 @@ size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
#ifdef __AVX2__
- if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return;
+ //if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return;
#endif
using Q = QuantizerIQ4KT;
assert(k % Q::kSuperBlockSize == 0);
@@ -8732,23 +8776,20 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
const int nb = k / Q::kSuperBlockSize;
const float * dptr = (const float *)x;
const float d = dptr[0] * Q::kScale;
- const float row_av = dptr[1];
- x = (const block_iq4_kt *)(dptr + 2);
- auto& deq = iq4kt_quantizer();
+ x = (const block_iq4_kt *)(dptr + 1);
+ auto& deq = iq4kt_dequantizer();
for (int ibl = 0; ibl < nb; ++ibl) {
auto shb = x[ibl].qs;
auto ql = (const uint8_t *)(shb + Q::kNblock);
auto qh = ql + kNumGroups;
for (int ib = 0; ib < Q::kNblock; ++ib) {
int offset = shb[ib] & 1 ? 32768 + 4096 : 4096;
- //auto& deq = shb[ib] & 1 ? deq2 : deq1;
int ls = int((shb[ib] & 0xff) >> 1) - 64;
float sl = d * ls;
for (int ig = 0; ig < Q::kNg; ++ig) {
int jj = ib*Q::kNg+ig;
uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12);
deq.set_values(idx, y, sl, offset);
- for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av;
y += Q::kGroupSize;
}
}
diff --git a/src/llama.cpp b/src/llama.cpp
index dfd53337..af8ef9be 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18627,6 +18627,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) {
new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 ||
new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 ||
new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4||
+ new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT ||
new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4) {
if (nx % QK_K != 0) {
LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));