Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/ggml-common.h           |  12
-rw-r--r--  ggml/src/ggml-cuda.cu            |   1
-rw-r--r--  ggml/src/ggml-cuda/common.cuh    |   7
-rw-r--r--  ggml/src/ggml-cuda/convert.cu    |  31
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu   |  42
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cuh  |   4
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu       |   3
-rw-r--r--  ggml/src/ggml-metal.m            |  29
-rw-r--r--  ggml/src/ggml-metal.metal        | 144
-rw-r--r--  ggml/src/ggml-quants.c           |   1
-rw-r--r--  ggml/src/ggml.c                  |  21
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp     | 317
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp    | 107
-rw-r--r--  ggml/src/iqk/iqk_quantize.h      |   6
14 files changed, 706 insertions, 19 deletions
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 423797b6..5847d903 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -407,7 +407,7 @@ typedef struct {
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
//
-// Bitnet - implemented as 1.75 bpw
+// Bitnet - implemented as 1.625 bpw
// The block scale is a waste, but it allows us to plug it in without any additional
// changes to ggml.
//
@@ -418,13 +418,21 @@ typedef struct {
} block_iq1_bn;
static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
//
-// Bitnet - implemented as 2.25 bpw
+// Bitnet - implemented as 2.0 bpw
//
#define QK_IQ2BN 64
typedef struct {
uint8_t qs[QK_IQ2BN/4];
} block_iq2_bn;
static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
+//
+// TriLM - implemented as 2.0625 bpw
+//
+typedef struct {
+ ggml_half d;
+ uint8_t qs[QK_K/4];
+} block_iq2_tn;
+static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iq2_tn block size/padding");
// Used by IQ1_M quants
typedef union {
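The bits-per-weight figures quoted in the comments above follow directly from the block layouts; as a back-of-the-envelope check (assuming QK_K == 256, a 2-byte ggml_half, and 64-weight Bitnet blocks):
  block_iq2_tn: (2 + 256/4) bytes * 8 bits / 256 weights = 66*8/256 = 2.0625 bpw
  block_iq2_bn: (64/4) bytes * 8 bits / 64 weights       = 16*8/64  = 2.0    bpw
  block_iq1_bn: 13 bytes * 8 bits / 64 weights           = 104/64   = 1.625  bpw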
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index d34aa386..a115a1b4 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2759,6 +2759,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
return true;
default:
return false;
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index fbc52aa9..c18e865a 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -656,6 +656,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
};
template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_TN> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_K;
+ static constexpr int qi = QI2_K;
+};
+
+template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qk = QK4_NL;
static constexpr int qr = QR4_NL;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index ed7e4bd0..47ab92f0 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -154,6 +154,27 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
}
template<typename dst_t>
+static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_tn * x = (const block_iq2_tn *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t n = tid/32;
+ const int64_t l = tid - 32*n;
+ const int64_t is = 8*n + l/16;
+
+ const uint8_t q = x[i].qs[32*n + l];
+ dst_t * y = yy + i*QK_K + 128*n;
+
+ float d = __half2float(x[i].d);
+ y[l+ 0] = d * ((q >> 0) & 3) - d;
+ y[l+32] = d * ((q >> 2) & 3) - d;
+ y[l+64] = d * ((q >> 4) & 3) - d;
+ y[l+96] = d * ((q >> 6) & 3) - d;
+}
+
+template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
@@ -647,6 +668,12 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k
}
template<typename dst_t>
+static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -812,6 +839,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_IQ2_TN:
+ return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
@@ -871,6 +900,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_IQ2_TN:
+ return dequantize_row_iq2_tn_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index acb495d1..8def1547 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -469,6 +469,41 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
}
+#define VDR_IQ2_TN_Q8_1_MMVQ 1
+#define VDR_IQ2_TN_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx;
+
+ const int bq8_offset = QR2_K * (iqs / QI8_1);
+
+ const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs;
+ int v = q16[0] | (q16[1] << 16);
+
+ float sumf = 0;
+ for (int i = 0; i < QR2_K; ++ i) {
+ int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
+ float d8 = __low2float(bq8_1[bq8_offset + i].ds);
+ sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
+ v >>= 2;
+ }
+ return __half2float(bq2->d) * sumf;
+
+ //float sumf_d = 0;
+ //float sumf_m = 0;
+ //for (int i = 0; i < QR2_K; ++ i) {
+ // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
+ // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds);
+ // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0);
+ // sumf_m += d8.y;
+ // v >>= 2;
+ //}
+ //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m);
+
+}
+
} // namespace
void mul_mat_vec_iq2_k_q8_1_cuda(
@@ -499,3 +534,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
+void mul_mat_vec_iq2_tn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
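vec_dot_iq2_tn_q8_1 above relies on the identity used throughout this patch for ternary weights stored with a +1 offset: with t in {-1, 0, +1} stored as q = t + 1 in {0, 1, 2}, d*sum(t_j*y_j) = d*(sum(q_j*y_j) - sum(y_j)), which is why each iteration issues one dp4a against the masked quants and a second dp4a against 0x01010101. A scalar model of the same computation (illustration only, not the CUDA code path; q here is the already unpacked 2-bit quants):
#include <cstdint>
// Scalar model of vec_dot_iq2_tn_q8_1: q holds trits t = q - 1, so the
// offset is removed by subtracting the plain sum of the q8 activations.
static float dot_iq2_tn_scalar(float d, const uint8_t * q, const int8_t * y, int n) {
    int sum_qy = 0, sum_y = 0;
    for (int j = 0; j < n; ++j) {
        sum_qy += int(q[j]) * y[j];   // q[j] in {0, 1, 2}
        sum_y  += y[j];
    }
    return d * float(sum_qy - sum_y);
}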
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 9a33af0d..3dc5f41c 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -16,3 +16,7 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+void mul_mat_vec_iq2_tn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 56bf3ebe..428d822f 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -426,6 +426,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ2_BN:
mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
+ case GGML_TYPE_IQ2_TN:
+ mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 48384923..d54d252c 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -88,6 +88,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K,
@@ -122,6 +123,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32,
@@ -152,6 +154,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32,
@@ -179,6 +182,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32,
@@ -206,6 +210,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32,
@@ -577,6 +582,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN, get_rows_iq2_tn, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K, get_rows_iq2_k, true);
@@ -611,6 +617,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32, mul_mv_iq2_tn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_K_F32, mul_mv_iq2_k_f32, ctx->support_simdgroup_reduction);
@@ -641,6 +648,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32, mul_mv_id_iq2_tn_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32, mul_mv_id_iq2_k_f32, ctx->support_simdgroup_reduction);
@@ -668,6 +676,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32, mul_mm_iq2_tn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32, mul_mm_iq2_k_f32, ctx->support_simdgroup_mm);
@@ -695,6 +704,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32, mul_mm_id_iq2_tn_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32, mul_mm_id_iq2_k_f32, ctx->support_simdgroup_mm);
@@ -1728,6 +1738,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_TN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break;
@@ -1904,6 +1915,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32].pipeline;
} break;
+ case GGML_TYPE_IQ2_TN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_TN_F32].pipeline;
+ } break;
case GGML_TYPE_IQ4_NL:
{
nth0 = 4;
@@ -1972,7 +1989,7 @@ static enum ggml_status ggml_metal_graph_compute(
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
- src0t == GGML_TYPE_IQ3_K) {
+ src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2074,6 +2091,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_TN_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break;
@@ -2244,6 +2262,12 @@ static enum ggml_status ggml_metal_graph_compute(
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline;
} break;
+ case GGML_TYPE_IQ2_TN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_TN_F32].pipeline;
+ } break;
case GGML_TYPE_IQ4_NL:
{
nth0 = 4;
@@ -2323,7 +2347,7 @@ static enum ggml_status ggml_metal_graph_compute(
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
- src0t == GGML_TYPE_IQ3_K) {
+ src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2384,6 +2408,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
+ case GGML_TYPE_IQ2_TN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_TN ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 53b2ddb8..1366905d 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -3330,6 +3330,129 @@ kernel void kernel_mul_mv_q2_K_f32(
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
+void kernel_mul_mv_iq2_tn_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int step = sizeof(block_iq2_tn) * nb / 2;
+
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int iq = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+ const int is = (8*ir)/16;// 0 or 1
+
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float sumy = 0.f;
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0];
+ yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8];
+ yl[i+16] = y4[i+64]; sumy += yl[i+16];
+ yl[i+24] = y4[i+96]; sumy += yl[i+24];
+ }
+
+ device const half * dh = &x[ib].d;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+ float dall = dh[0];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy);
+
+ qs += step;
+ dh += step;
+ }
+
+ y4 += 4 * QK_K;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_tn_f32")]]
+kernel void kernel_mul_mv_iq2_tn_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_tn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
void kernel_mul_mv_q3_K_f32_impl(
device const void * src0,
device const float * src1,
@@ -6009,7 +6132,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
}
template <typename type4x4>
-void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) {
const float d = xb->d;
const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
@@ -6028,6 +6151,21 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
}
template <typename type4x4>
+void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) {
+ const half d = xb->d;
+ device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1);
+
+ il = (il/2)%4;
+
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const half dl = d * coef;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - d;
+ }
+}
+
+template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
@@ -6892,6 +7030,7 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_tn, QK_NL, dequantize_iq2_tn>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
@@ -6926,6 +7065,7 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_tn, QK_NL, dequantize_iq2_tn>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
@@ -6960,6 +7100,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_tn, QK_NL, dequantize_iq2_tn>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
@@ -7175,6 +7316,7 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_tn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_tn_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 415249fb..9b3fddbc 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -14996,6 +14996,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ3_K: break;
case GGML_TYPE_IQ4_K: break;
case GGML_TYPE_IQ5_K: break;
+ case GGML_TYPE_IQ2_TN: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4ce9948d..5c817030 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -882,6 +882,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K64,
.nrows = 1,
},
+ [GGML_TYPE_IQ2_TN] = {
+ .type_name = "iq2_tn",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_tn),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_tn,
+ .from_float = quantize_row_iq2_tn,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq2_tn_ref,
+ .vec_dot = vec_dot_iq2_tn_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
[GGML_TYPE_IQ4_NL] = {
.type_name = "iq4_nl",
.blck_size = QK4_NL,
@@ -3375,6 +3387,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break;
case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break;
case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
+ case GGML_FTYPE_MOSTLY_IQ2_TN: wtype = GGML_TYPE_IQ2_TN; break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
@@ -9628,6 +9641,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -10012,6 +10026,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -10146,6 +10161,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -13069,6 +13085,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -13263,6 +13280,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -13531,6 +13549,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -14126,6 +14145,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ2_TN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
@@ -20865,6 +20885,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_TN: result = quantize_iq2_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3a81d3ac..db83b841 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -692,6 +692,16 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
};
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+ DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ }
+ Q2Bits bits;
+};
+
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
@@ -960,6 +970,16 @@ inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i *
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
+template <typename Q8>
+inline void compute_block_iq2tn(int iy, int i, float d, const Q8& q8, const __m512i * values, __m512 * accd) {
+ auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
+ auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
+ _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
+ values[0], q8.load_quants64(iy, i, 0)), values[1], q8.load_quants64(iy, i, 1)),
+ values[2], q8.load_quants64(iy, i, 2)), values[3], q8.load_quants64(iy, i, 3));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+}
+
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -985,14 +1005,22 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
deq.new_block(i, q8, accm, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
- //compute_block(iy, i, deq.d, q8, deq.bits.values, scales, accd);
- const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
- const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
- const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
- const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
- auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
- sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
- accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+ auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
+ auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
+ _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
+ deq.bits.values[0], q8.load_quants64(iy, i, 0)), deq.bits.values[1], q8.load_quants64(iy, i, 1)),
+ deq.bits.values[2], q8.load_quants64(iy, i, 2)), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ } else {
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ }
}
}
@@ -1034,19 +1062,33 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
- for (int kx = 0; kx < k_nx; ++kx) {
- compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+ for (int kx = 0; kx < k_nx; ++kx) {
+ compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd);
+ }
+ } else {
+ for (int kx = 0; kx < k_nx; ++kx) {
+ compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+ }
}
}
if (2*(nb/2) < nb) {
int i0 = 2*(nb/2);
deq[0]->new_block(i0, q8, &accm, scales);
- compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+ compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd);
+ } else {
+ compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+ }
}
- auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
- info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+ if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+ info.store(ix, 0, _mm512_reduce_add_ps(accd));
+ } else {
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+ info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+ }
}
}
@@ -1439,6 +1481,74 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
const __m256i mh = _mm256_set1_epi8(0x30);
};
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+ DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ inline void new_block(int i) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ }
+
+ Q2Bits bits;
+};
+
+
+template <int nrc_y>
+IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Q8<nrc_y> q8(info);
+ DequantizerIQ2TN deq(vx, bx);
+
+ __m256 accd[nrc_y];
+ const auto m1 = _mm256_set1_epi16(1);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ __m256i sumi[nrc_y];
+ deq.new_block(i);
+
+ deq.prepare(i, 0);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)),
+ _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1)));
+ sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)),
+ _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 3))), sumi[iy]);
+ }
+ deq.prepare(i, 1);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 4)),
+ _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 5))), sumi[iy]);
+ sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 6)),
+ _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]);
+ sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i));
+ }
+ if (i > 0) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]);
+ }
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])));
+ }
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
template <typename Dequantizer, int nrc_y>
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
@@ -1931,7 +2041,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
- accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+ accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot;
#endif
}
}
@@ -3156,6 +3266,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerQ2K>(mm);
break;
+ case GGML_TYPE_IQ2_TN:
+ assert (ne00 % QK_K == 0);
+#ifdef HAVE_FANCY_SIMD
+ MulMat::set_functions<DequantizerIQ2TN>(mm);
+#else
+ mm.funcs[0] = mul_mat_iq2tn_q8_K<1>;
+ mm.funcs[1] = mul_mat_iq2tn_q8_K<2>;
+ mm.funcs[2] = mul_mat_iq2tn_q8_K<3>;
+ mm.funcs[3] = mul_mat_iq2tn_q8_K<4>;
+ mm.funcs[4] = mul_mat_iq2tn_q8_K<5>;
+ //mm.funcs[5] = mul_mat_iq2tn_q8_K<6>;
+ //mm.funcs[6] = mul_mat_iq2tn_q8_K<7>;
+ //mm.funcs[7] = mul_mat_iq2tn_q8_K<8>;
+#endif
+ break;
case GGML_TYPE_Q3_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerQ3K>(mm);
@@ -4280,6 +4405,159 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
};
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+ DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return true; }
+
+ //template <typename Q8>
+ //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
+ // d = GGML_FP16_TO_FP32(x[i].d);
+ //}
+
+ inline void new_block(int i) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ }
+
+ template <typename Q8>
+ inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
+
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
+
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
+ }
+ }
+ template <typename Q8>
+ inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) {
+ auto q8b_1 = q8.load_quants(0, i, 4*j+0);
+ sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
+
+ auto q8b_2 = q8.load_quants(0, i, 4*j+1);
+ sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
+
+ q8b_1 = q8.load_quants(0, i, 4*j+2);
+ sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]);
+
+ q8b_2 = q8.load_quants(0, i, 4*j+3);
+ sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]);
+ }
+
+ IQK_ALWAYS_INLINE void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ auto m1 = vdupq_n_s8(1);
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1);
+ bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1);
+ }
+ }
+
+ Q2bits bits;
+
+ float d;
+};
+
+template <int nrc_y>
+void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_K> q8(info);
+
+ DequantizerIQ2TN deq(vx, bx, nrc_y);
+ float32x4_t acc[nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ int32x4_t sumi[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
+
+ deq.new_block(i);
+ deq.prepare(i, 0);
+ deq.compute(q8, i, 0, sumi);
+ deq.prepare(i, 1);
+ deq.compute(q8, i, 1, sumi);
+
+ if (i > 0) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
+ }
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vmulq_f32(vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(acc[iy]));
+ }
+ }
+}
+void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<1, block_q8_K> q8(info);
+
+ DequantizerIQ2TN deq(vx, bx, 1);
+
+ auto m1 = vdup_n_s16(-1);
+ float32x4_t acc[2];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ int32x4_t sumi[2] = {};
+ deq.new_block(i);
+ auto bsums = q8.load_bsums(0, i);
+ bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]);
+ sumi[0] = vmlal_s16(sumi[0], vget_low_s16 (bsums.val[0]), m1);
+ sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1);
+ deq.bits.prepare(deq.x[i].qs);
+ deq.compute1(q8, i, 0, sumi);
+ deq.bits.prepare(deq.x[i].qs+32);
+ deq.compute1(q8, i, 1, sumi);
+
+ auto vd = vdupq_n_f32(deq.d*q8.scale(0, i));
+ if (i > 0) {
+ acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd);
+ acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd);
+ } else {
+ acc[0] = vmulq_f32(vcvtq_f32_s32(sumi[0]), vd);
+ acc[1] = vmulq_f32(vcvtq_f32_s32(sumi[1]), vd);
+ }
+
+ }
+
+ acc[0] = vaddq_f32(acc[0], acc[1]);
+ info.store(ix, 0, vaddvq_f32(acc[0]));
+ }
+}
+
template <int nrc_y, typename Dequantizer>
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -5269,6 +5547,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q2_K:
MulMat::set_functions<DequantizerQ2K>(m);
break;
+ case GGML_TYPE_IQ2_TN:
+ //MulMat::set_functions<DequantizerIQ2TN>(m);
+ m.funcs[0] = mul_mat_iq2tn_K_q8_K_1;
+ m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>;
+ m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>;
+ m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>;
+ m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>;
+ m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>;
+ m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>;
+ m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>;
+ break;
case GGML_TYPE_Q3_K:
MulMat::set_functions<DequantizerQ3K>(m);
break;
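On the AVX512 and NEON paths above, the same -1 offset is handled once per super-block: because IQ2_TN has a single scale per 256 weights, the correction term sum(y_j) can be taken from the precomputed Q8_K block sums instead of an extra dot product. Schematically (scalar sketch with hypothetical names; q is the unpacked quants, bsums holds 16 partial sums of 16 activations each):
#include <cstdint>
// One 256-wide block: sum(t_j*y_j) = sum(q_j*y_j) - sum(y_j), with sum(y_j)
// recovered from the 16 precomputed Q8_K group sums.
static int iq2_tn_block_sumi(const uint8_t * q, const int8_t * y, const int16_t * bsums) {
    int sum_qy = 0;
    for (int j = 0; j < 256; ++j) sum_qy += int(q[j]) * y[j];
    int sum_y = 0;
    for (int g = 0; g < 16; ++g) sum_y += bsums[g];
    return sum_qy - sum_y;
}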
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index c840fabf..1cba1532 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -1514,3 +1514,110 @@ size_t quantize_iq5_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
}
return nrows * nblock * sizeof(block_iq5_k);
}
+
+//
+// ========================== IQ2_TN
+//
+
+void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) {
+ GGML_ASSERT(k%QK_K == 0);
+
+ int nb = k/QK_K;
+
+ auto quantize = [] (float xmax, float x) {
+ return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 1 : 2;
+ };
+
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ auto xb = x + QK_K*ibl;
+ float max = xb[0];
+ for (int j = 0; j < QK_K; ++j) {
+ float ax = fabsf(xb[j]);
+ max = std::max(ax, max);
+ }
+ y[ibl].d = GGML_FP32_TO_FP16(max);
+ auto qs = y[ibl].qs;
+ for (int l = 0; l < QK_K/128; ++l) {
+ for (int j = 0; j < 32; ++j) {
+ qs[j] = quantize(max, xb[j]) | (quantize(max, xb[j+32]) << 2) | (quantize(max, xb[j+64]) << 4) | (quantize(max, xb[j+96]) << 6);
+ }
+ xb += 128;
+ qs += 32;
+ }
+ }
+}
+
+void quantize_row_iq2_tn(const float * x, void * y, int64_t k) {
+ quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k);
+}
+
+size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * /*imatrix*/) {
+ auto row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
+ char * qrow = (char *)dst;
+ for (int row = 0; row < nrows; ++row) {
+ quantize_row_iq2_tn_ref(src, (block_iq2_tn *)qrow, n_per_row);
+ qrow += row_size;
+ src += n_per_row;
+ }
+ return row_size*nrows;
+}
+
+void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) {
+ GGML_ASSERT(k%QK_K == 0);
+ int nb = k/QK_K;
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ float d = GGML_FP16_TO_FP32(x[ibl].d);
+ auto qs = x[ibl].qs;
+ for (int l = 0; l < QK_K/128; ++l) {
+ for (int j = 0; j < 32; ++j) {
+ y[j+ 0] = d*((qs[j] >> 0) & 3) - d;
+ y[j+32] = d*((qs[j] >> 2) & 3) - d;
+ y[j+64] = d*((qs[j] >> 4) & 3) - d;
+ y[j+96] = d*((qs[j] >> 6) & 3) - d;
+ }
+ y += 128;
+ qs += 32;
+ }
+ }
+}
+
+void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+
+ const int nb = n / QK_K;
+
+ const block_iq2_tn * x = (const block_iq2_tn *)vx;
+ const block_q8_K * y = (const block_q8_K *)vy;
+
+ float sumf = 0;
+
+ for (int i = 0; i < nb; i++) {
+ float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ auto qs = x[i].qs;
+ auto q8 = y[i].qs;
+ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+ for (int j = 0; j < QK_K/16; ++j) sumi1 -= y[i].bsums[j];
+ for (int l = 0; l < QK_K/128; ++l) {
+ for (int j = 0; j < 32; ++j) {
+ sumi1 += q8[j+ 0] * (qs[j] & 0x03);
+ sumi2 += q8[j+32] * (qs[j] & 0x0c);
+ sumi3 += q8[j+64] * (qs[j] & 0x30);
+ sumi4 += q8[j+96] * (qs[j] & 0xc0);
+ }
+ q8 += 128;
+ qs += 32;
+ }
+ sumf += d * (sumi1 + 0.25f*sumi2 + 0.0625f*sumi3 + 0.015625f*sumi4);
+ }
+ *s = sumf;
+}
+
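A minimal round-trip sketch of the reference quantizer above (assumes ggml-common.h and iqk_quantize.h are on the include path; illustration only): every dequantized value comes back as one of {-max, 0, +max} up to fp16 rounding of max = max_j |x_j|, since the lambda maps x to 0/1/2 at the +-0.5*max thresholds and dequantization computes d*q - d.
#include <vector>
// Quantize one 256-value row to IQ2_TN and expand it back to float.
static void iq2_tn_roundtrip_demo(const float * x /* QK_K = 256 values */) {
    block_iq2_tn blk;
    quantize_row_iq2_tn_ref(x, &blk, QK_K);
    std::vector<float> y(QK_K);
    dequantize_row_iq2_tn(&blk, y.data(), QK_K);
    // y[j] is -d, 0 or +d with d = GGML_FP16_TO_FP32(blk.d)
}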
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 0295eb99..80a9012b 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -37,6 +37,12 @@ size_t quantize_iq5_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
void dequantize_row_iq5_k(const block_iq5_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq5_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq2_tn_ref(const float * GGML_RESTRICT x, block_iq2_tn * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
#ifdef __cplusplus
}
#endif