 examples/quantize/quantize.cpp |  4
 ggml/include/ggml.h            |  2
 ggml/src/ggml-common.h         |  7
 ggml/src/ggml-cuda.cu          | 12
 ggml/src/ggml-cuda/convert.cu  | 19
 ggml/src/ggml-cuda/iqk_mmvq.cu | 17
 ggml/src/ggml-metal.metal      | 66
 ggml/src/ggml-quants.c         |  2
 ggml/src/ggml.c                | 84
 ggml/src/iqk/iqk_mul_mat.cpp   | 69
 ggml/src/iqk/iqk_quantize.cpp  | 23
 src/llama.cpp                  |  4
 12 files changed, 171 insertions(+), 138 deletions(-)
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index c6153e45..c11b8631 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -28,8 +28,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", },
{ "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
- { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.69 bpw quantization (TriLM)", },
- { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.06 bpw quantization (TriLM)", },
+ { "IQ1_TN", LLAMA_FTYPE_MOSTLY_IQ1_TN, " 1.63 bpw quantization (TriLM)", },
+ { "IQ2_TN", LLAMA_FTYPE_MOSTLY_IQ2_TN, " 2.00 bpw quantization (TriLM)", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5b46a70d..6ac30b0f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -744,6 +744,7 @@ extern "C" {
GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
+ // TODO: remove the following from the public API to avoid unnecessary assumptions about data layout
GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type);
GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
@@ -2517,6 +2518,7 @@ extern "C" {
int64_t ncols; // number of columns to process simultaneously
ggml_gemv_t gemv;
ggml_gemm_t gemm;
+ int64_t row_meta_size;
} ggml_type_traits_t;
GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 40a4b53c..bb0c4864 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -400,14 +400,13 @@ static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/paddi
// TriLM - implemented as 2.0625 bpw
//
typedef struct {
- uint8_t qs[54];
+ uint8_t qs[52];
} block_iq1_tn;
-static_assert(sizeof(block_iq1_tn) == 54, "wrong iq1_tn block size/padding");
+static_assert(sizeof(block_iq1_tn) == 52, "wrong iq1_tn block size/padding");
typedef struct {
- ggml_half d;
uint8_t qs[QK_K/4];
} block_iq2_tn;
-static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding");
+static_assert(sizeof(block_iq2_tn) == QK_K/4, "wrong iq2_tn block size/padding");
// Used by IQ1_M quants
typedef union {
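An IQ2_TN row is now a 4-byte float scale followed by n_per_row/QK_K scale-free blocks; IQ1_TN does the same with a 2-byte ggml_half. A minimal sketch of how a consumer addresses such a row, mirroring what vec_dot_iq2_tn_q8_k does further down (the helper name is illustrative, not from the patch):

    // Row layout after this change: [ float d | block_iq2_tn, block_iq2_tn, ... ]
    static inline const block_iq2_tn * iq2_tn_row(const void * row, float * d) {
        const float * dptr = (const float *) row;  // per-row scale comes first
        *d = dptr[0];
        return (const block_iq2_tn *)(dptr + 1);   // quant blocks follow
    }
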
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 87d7e17e..ca57efbd 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -1180,19 +1180,22 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const int64_t nb3 = src->nb[3];
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
+ const int64_t rs = ggml_row_size(type, ne0);
const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
- if (nb0 == ts && nb1 == ts*ne0/bs) {
+ if (nb0 == ts && nb1 == rs) {
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
} else if (nb0 == ts) {
+ // TODO: this only works if the row does not contain meta data
return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
} else {
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
- void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+ void * rd = (void *) (dst_ptr + i1*rs);
// pretend the row is a matrix with cols=1
+ // TODO: this only works if the row does not contain meta data
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
if (r != cudaSuccess) {
return r;
@@ -1441,8 +1444,7 @@ static void ggml_cuda_op_mul_mat(
const int64_t i02_divisor = ne12 / ne02;
- const size_t src0_ts = ggml_type_size(src0->type);
- const size_t src0_bs = ggml_blck_size(src0->type);
+ const size_t src0_rs = ggml_row_size(src0->type, ne00);
const size_t q8_1_ts = sizeof(block_q8_1);
const size_t q8_1_bs = QK8_1;
@@ -1608,7 +1610,7 @@ static void ggml_cuda_op_mul_mat(
}
// for split tensors the data begins at i0 == i0_offset_low
- char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+ char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * ne01*src0_rs;
float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
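The old fast-path test nb1 == ts*ne0/bs undercounts rows that carry meta data. With the IQ2_TN values registered in ggml.c (type size 64, block size 256), a 4096-column row has 1024 bytes of quants but a 1028-byte stride once the leading float scale is included, and ggml_row_size() is the helper that returns the latter:

    const int64_t ts = 64, bs = 256, ne0 = 4096;  // IQ2_TN example values
    const int64_t payload  = ts*ne0/bs;           // 1024 bytes of 2-bit quants
    const int64_t row_size = 4 + payload;         // 1028 bytes = ggml_row_size()
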
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 4b1be7c1..c74b030b 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -154,19 +154,23 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
}
template<typename dst_t>
-static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ int64_t n_per_row, int64_t row_size) {
- const int64_t i = blockIdx.x;
- const block_iq2_tn * x = (const block_iq2_tn *) vx;
+ int64_t ii = blockIdx.x;
+ int64_t row = (QK_K * ii) / n_per_row;
+ const char * cx = (const char *)vx + row * row_size;
+ float d = *(const float *)cx;
+ const block_iq2_tn * x = (const block_iq2_tn *)(cx + sizeof(float));
+ int64_t i = ii - (row*n_per_row)/QK_K;
const int64_t tid = threadIdx.x;
const int64_t n = tid/32;
const int64_t l = tid - 32*n;
const uint8_t q = x[i].qs[32*n + l];
- dst_t * y = yy + i*QK_K + 128*n;
+ dst_t * y = yy + ii*QK_K + 128*n;
- float d = __half2float(x[i].d);
y[l+ 0] = d * ((q >> 0) & 3) - d;
y[l+32] = d * ((q >> 2) & 3) - d;
y[l+64] = d * ((q >> 4) & 3) - d;
@@ -743,8 +747,9 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t n
template<typename dst_t>
static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
- const int nb = k / QK_K;
- dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y);
+ const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
+ const int nb = (k + 255) / 256;
+ dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y, n_per_row, row_size);
}
template<typename dst_t>
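The dequantize kernel is now launched over every superblock of every row, (k + 255)/256 thread blocks in total, and each block recovers its position from blockIdx.x, since a row is no longer a whole number of type-size units. With QK_K = 256 and, say, n_per_row = 1024 (4 superblocks per row), blockIdx.x = 6 maps to:

    int64_t ii  = 6;                    // global superblock index
    int64_t row = (256*ii)/1024;        // = 1: second row
    int64_t i   = ii - (row*1024)/256;  // = 2: third superblock in that row
    // the row's data starts at vx + row*row_size; its first 4 bytes are the scale
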
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index a890f6b3..b2c32c0c 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -519,7 +519,8 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
- const block_iq2_tn * bq2 = (const block_iq2_tn *) vbq + kbx;
+ float scale = *(const float *)vbq;
+ const block_iq2_tn * bq2 = (const block_iq2_tn *)((const char *)vbq + sizeof(float)) + kbx;
const int bq8_offset = QR2_K * (iqs / QI8_1);
@@ -533,19 +534,7 @@ static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
v >>= 2;
}
- return __half2float(bq2->d) * sumf;
-
- //float sumf_d = 0;
- //float sumf_m = 0;
- //for (int i = 0; i < QR2_K; ++ i) {
- // int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
- // float2 d8 = __half22float2(bq8_1[bq8_offset + i].ds);
- // sumf_d += d8.x * ggml_cuda_dp4a(v & 0x03030303, u, 0);
- // sumf_m += d8.y;
- // v >>= 2;
- //}
- //return __half2float(bq2->d) * (sumf_d - 0.125f * sumf_m);
-
+ return scale * sumf;
}
static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 259fa609..e2e45029 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -3778,21 +3778,21 @@ void kernel_mul_mv_iq2_tn_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
+ const int row_size = nb*sizeof(block_iq2_tn) + 4;
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size;
- device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const char * cx = (device const char *) src0 + first_row*row_size + offset0;
+ device const float * y = (device const float*) src1 + r1*ne10 + im*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
-
- const int step = sizeof(block_iq2_tn) * nb / 2;
+ float drow[N_DST];
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
@@ -3802,6 +3802,8 @@ void kernel_mul_mv_iq2_tn_f32_impl(
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+ for (int row = 0; row < N_DST; row++) drow[row] = *((device const float *)(cx + row*row_size));
+
for (int ib = ix; ib < nb; ib += 4) {
float sumy = 0.f;
@@ -3812,7 +3814,7 @@ void kernel_mul_mv_iq2_tn_f32_impl(
yl[i+24] = y4[i+96]; sumy += yl[i+24];
}
- device const half * dh = &x[ib].d;
+ device const block_iq2_tn * x = (device const block_iq2_tn *)(cx + 4);
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
for (int row = 0; row < N_DST; row++) {
@@ -3829,14 +3831,12 @@ void kernel_mul_mv_iq2_tn_f32_impl(
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
}
- float dall = dh[0];
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f +
- (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f +
- (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f +
- (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy);
+ sumf[row] += (acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy;
- qs += step;
- dh += step;
+ qs += row_size/2;
}
y4 += 4 * QK_K;
@@ -3845,7 +3845,7 @@ void kernel_mul_mv_iq2_tn_f32_impl(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = drow[row]*all_sum;
}
}
}
@@ -5504,7 +5504,7 @@ void kernel_mul_mv_iq1_tn_f32_impl(
// Why are we not passing in src0->nb[0]?
// But because we are not, we need to use this hack
- const uint row_size = sizeof(block_iq1_tn)*(ne00/QK_K);
+ const uint row_size = 2+sizeof(block_iq1_tn)*(ne00/QK_K);
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
@@ -6850,16 +6850,14 @@ void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & re
template <typename type4x4>
void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) {
- const half d = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1);
il = (il/2)%4;
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- const half dl = d * coef;
for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - d;
+ reg[i/4][i%4] = coef * (q[i] & mask) - 1;
}
}
@@ -7446,30 +7444,28 @@ struct DefaultDequantizer {
short il;
};
-template <typename T4x4>
-struct DequantizerIQ1TN {
+template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&)>
+struct DequantizerRS{
using type4x4 = T4x4;
- using Block = block_iq1_bn;;
- DequantizerIQ1TN(device const char * cx, short il = 0) : il(il) {
- d = *(device const half *)cx;
- x = (device const Block *)(cx + sizeof(half));
+ DequantizerRS(device const char * cx, short il = 0) : il(il) {
+ d = *(device const Scale *)cx;
+ x = (device const Block *)(cx + sizeof(Scale));
}
inline void convert(thread T4x4& t) const {
- dequantize_iq1_bn(x, il, t);
+ dequantize(x, il, t);
t *= d;
}
inline void convert(int64_t ind, thread T4x4& t) {
- dequantize_iq1_bn(x + ind/4, ind%4, t);
+ dequantize(x + ind/nl, ind%nl, t);
t *= d;
}
inline void next() {
- constexpr int short nl = 4;
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2+nl-1)/nl : x;
}
device const Block * x;
short il;
- half d;
+ Scale d;
};
// each block_q contains 16*nl weights
@@ -7821,7 +7817,6 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_tn, QK_NL, dequantize_iq2_tn>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
@@ -7842,7 +7837,8 @@ template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq6_k, QK_NL, dequantize_iq6_k>;
template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>;
template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>;
-template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerIQ1TN<float4x4>>;
+template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>;
+template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>;
//
// matrix-matrix multiplication
@@ -7862,7 +7858,6 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>;
-template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>>;
@@ -7883,7 +7878,8 @@ template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_bn, 4, dequantize_iq1_bn>>;
template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_bn, 4, dequantize_iq2_bn>>;
-template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerIQ1TN<half4x4>>;
+template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>;
+template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>;
//
// indirect matrix-matrix multiplication
@@ -7900,7 +7896,6 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_1, 2, dequantize_q5_1>>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q8_0, 2, dequantize_q8_0>>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q2_K, QK_NL, dequantize_q2_K>>;
-template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq2_tn, QK_NL, dequantize_iq2_tn>>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q3_K, QK_NL, dequantize_q3_K>>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_K, QK_NL, dequantize_q4_K>>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_K, QK_NL, dequantize_q5_K>>;
@@ -7921,7 +7916,8 @@ template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_k, QK_NL, dequantize_iq4_k>>;
template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>;
template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
-template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerIQ1TN<half4x4>>;
+template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn>>;
+template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_tn, float, 16, dequantize_iq2_tn>>;
//
// matrix-vector multiplication
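DequantizerIQ1TN, which hard-wired block_iq1_bn and a half scale, is generalized into DequantizerRS ("row scale"), parameterized on the block type, the scale type, the number of 4x4 tiles per block (nl), and the per-tile dequantizer. IQ1_TN instantiates it with <block_iq1_bn, half, 4, dequantize_iq1_bn> and IQ2_TN with <block_iq2_tn, float, 16, dequantize_iq2_tn>, which is also why the old DD-based iq2_tn kernels above are deleted. A plain C++ analog of the template (Metal address-space qualifiers omitted; the name is illustrative):

    // C++ sketch of DequantizerRS: a row is [ Scale d | Block x[0], x[1], ... ],
    // each Block holding 'nl' 4x4 tiles; dequantize() fills one tile.
    template <typename T4x4, typename Block, typename Scale, int nl,
              void (*dequantize)(const Block *, short, T4x4 &)>
    struct RowScaleDequantizer {
        RowScaleDequantizer(const char * cx) {
            d = *(const Scale *)cx;                  // per-row scale stored first
            x = (const Block *)(cx + sizeof(Scale)); // quant blocks follow
        }
        void convert(int64_t ind, T4x4 & t) const {
            dequantize(x + ind/nl, ind%nl, t);       // select block and tile
            t *= d;                                  // apply the per-row scale
        }
        const Block * x;
        Scale d;
    };
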
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index a9a25761..b0e70bcc 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -14798,7 +14798,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
return false;
}
- if (nbytes % ggml_type_size(type) != 0) {
+ if (type != GGML_TYPE_IQ2_TN && type != GGML_TYPE_IQ1_TN && nbytes % ggml_type_size(type) != 0) {
fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type));
return false;
}
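For the two TN types the per-row meta data means a tensor's byte count is no longer a multiple of the type size (an IQ2_TN row of 4096 weights is 1028 bytes, and 1028 % 64 != 0), so they have to be exempted from this check.
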
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 08b292b7..2804accd 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -651,24 +651,28 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.blck_size = 1,
.type_size = sizeof(int8_t),
.is_quantized = false,
+ .row_meta_size = 0,
},
[GGML_TYPE_I16] = {
.type_name = "i16",
.blck_size = 1,
.type_size = sizeof(int16_t),
.is_quantized = false,
+ .row_meta_size = 0,
},
[GGML_TYPE_I32] = {
.type_name = "i32",
.blck_size = 1,
.type_size = sizeof(int32_t),
.is_quantized = false,
+ .row_meta_size = 0,
},
[GGML_TYPE_I64] = {
.type_name = "i64",
.blck_size = 1,
.type_size = sizeof(int64_t),
.is_quantized = false,
+ .row_meta_size = 0,
},
[GGML_TYPE_F64] = {
.type_name = "f64",
@@ -676,6 +680,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(double),
.is_quantized = false,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_F32] = {
.type_name = "f32",
@@ -685,6 +690,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_F16] = {
.type_name = "f16",
@@ -697,6 +703,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
.vec_dot_type = GGML_TYPE_F16,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q4_0] = {
.type_name = "q4_0",
@@ -717,6 +724,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
#else
.nrows = 1,
#endif
+ .row_meta_size = 0,
},
[GGML_TYPE_Q4_1] = {
.type_name = "q4_1",
@@ -733,6 +741,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
#else
.nrows = 1,
#endif
+ .row_meta_size = 0,
},
[4] = { // GGML_TYPE_Q4_2
.type_name = "DEPRECATED",
@@ -745,6 +754,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT,
.nrows = 1,
+ .row_meta_size = 0,
},
[5] = { // GGML_TYPE_Q4_3
.type_name = "DEPRECATED",
@@ -757,6 +767,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q5_0] = {
.type_name = "q5_0",
@@ -773,6 +784,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q5_1] = {
.type_name = "q5_1",
@@ -785,6 +797,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q5_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q8_0] = {
.type_name = "q8_0",
@@ -806,6 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
#else
.nrows = 1,
#endif
+ .row_meta_size = 0,
},
[GGML_TYPE_Q8_1] = {
.type_name = "q8_1",
@@ -816,6 +830,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
.vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
@@ -828,6 +843,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q3_K] = {
.type_name = "q3_K",
@@ -840,6 +856,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q4_K] = {
.type_name = "q4_K",
@@ -852,6 +869,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q5_K] = {
.type_name = "q5_K",
@@ -864,6 +882,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q5_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q6_K] = {
.type_name = "q6_K",
@@ -876,6 +895,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ2_XXS] = {
.type_name = "iq2_xxs",
@@ -888,6 +908,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ2_XS] = {
.type_name = "iq2_xs",
@@ -900,6 +921,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ3_XXS] = {
.type_name = "iq3_xxs",
@@ -912,6 +934,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ3_S] = {
.type_name = "iq3_s",
@@ -924,6 +947,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq3_s_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ2_S] = {
.type_name = "iq2_s",
@@ -936,6 +960,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq2_s_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ1_S] = {
.type_name = "iq1_s",
@@ -948,6 +973,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq1_s_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ1_M] = {
.type_name = "iq1_m",
@@ -960,6 +986,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq1_m_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ1_BN] = {
.type_name = "iq1_bn",
@@ -972,6 +999,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq1_bn_q8_K64,
.vec_dot_type = GGML_TYPE_Q8_K64,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ2_BN] = {
.type_name = "iq2_bn",
@@ -984,6 +1012,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq2_bn_q8_K64,
.vec_dot_type = GGML_TYPE_Q8_K64,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ2_TN] = {
.type_name = "iq2_tn",
@@ -996,6 +1025,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq2_tn_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 4,
},
[GGML_TYPE_IQ1_TN] = {
.type_name = "iq1_tn",
@@ -1008,6 +1038,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq1_tn_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K64,
.nrows = 1,
+ .row_meta_size = 2,
},
[GGML_TYPE_IQ4_NL] = {
.type_name = "iq4_nl",
@@ -1020,6 +1051,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ4_XS] = {
.type_name = "iq4_xs",
@@ -1032,6 +1064,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_iq4_xs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q8_K] = {
.type_name = "q8_K",
@@ -1039,6 +1072,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q8_K),
.is_quantized = true,
.from_float = quantize_row_q8_K,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q8_K64] = {
.type_name = "q8_K64",
@@ -1046,6 +1080,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q8_K64),
.is_quantized = true,
.from_float = quantize_row_q8_K64,
+ .row_meta_size = 0,
},
[GGML_TYPE_BF16] = {
.type_name = "bf16",
@@ -1058,6 +1093,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
.vec_dot_type = GGML_TYPE_BF16,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q4_0_4_4] = {
.type_name = "q4_0_4x4",
@@ -1074,6 +1110,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.ncols = 4,
.gemv = ggml_gemv_q4_0_4x4_q8_0,
.gemm = ggml_gemm_q4_0_4x4_q8_0,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q4_0_4_8] = {
.type_name = "q4_0_4x8",
@@ -1090,6 +1127,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.ncols = 4,
.gemv = ggml_gemv_q4_0_4x8_q8_0,
.gemm = ggml_gemm_q4_0_4x8_q8_0,
+ .row_meta_size = 0,
},
[GGML_TYPE_Q4_0_8_8] = {
.type_name = "q4_0_8x8",
@@ -1106,6 +1144,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.ncols = 8,
.gemv = ggml_gemv_q4_0_8x8_q8_0,
.gemm = ggml_gemm_q4_0_8x8_q8_0,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ2_K] = {
.type_name = "iq2_k",
@@ -1118,6 +1157,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq2_k_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ3_K] = {
.type_name = "iq3_k",
@@ -1130,6 +1170,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq3_k_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ4_K] = {
.type_name = "iq4_k",
@@ -1142,6 +1183,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq4_k_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ5_K] = {
.type_name = "iq5_k",
@@ -1154,6 +1196,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq5_k_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
[GGML_TYPE_IQ6_K] = {
.type_name = "iq6_k",
@@ -1166,6 +1209,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq6_k_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
+ .row_meta_size = 0,
},
};
@@ -3585,6 +3629,10 @@ GGML_CALL int64_t ggml_nrows(const struct ggml_tensor * tensor) {
return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
}
+GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
+ return type_traits[type].blck_size;
+}
+
GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes;
size_t blck_size = ggml_blck_size(tensor->type);
@@ -3595,7 +3643,7 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
}
}
else {
- nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+ nbytes = tensor->nb[1]; //tensor->ne[0]*tensor->nb[0]/blck_size;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
@@ -3608,17 +3656,13 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
}
-GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
- return type_traits[type].blck_size;
-}
-
GGML_CALL size_t ggml_type_size(enum ggml_type type) {
return type_traits[type].type_size;
}
GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) {
assert(ne % ggml_blck_size(type) == 0);
- return ggml_type_size(type)*ne/ggml_blck_size(type);
+ return type_traits[type].row_meta_size + ggml_type_size(type)*ne/ggml_blck_size(type);
}
double ggml_type_sizef(enum ggml_type type) {
@@ -3764,7 +3808,7 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
return false;
}
- next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
+ next_nb = ggml_row_size(tensor->type, tensor->ne[0]); //next_nb*tensor->ne[0]/ggml_blck_size(tensor->type) + type_traits[tensor->type].row_meta_size;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
if (tensor->ne[i] != 1) {
if (i > n) {
@@ -4227,7 +4271,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
}
result->nb[0] = ggml_type_size(type);
- result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
+ result->nb[1] = ggml_row_size(type, ne[0]);
for (int i = 2; i < GGML_MAX_DIMS; i++) {
result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
}
@@ -13023,8 +13067,8 @@ static void ggml_compute_forward_mul_mat(
for (int64_t i12 = 0; i12 < ne12; i12++) {
if (counter++ % nth == ith) {
if (!iqk_mul_mat(ne01, ne11, ne00,
- src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
- src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
+ src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
0, 1)) goto IQK_MulMat_Not_Available1;
}
@@ -13036,8 +13080,8 @@ static void ggml_compute_forward_mul_mat(
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!iqk_mul_mat(ne01, ne11, ne00,
- src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
- src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
+ src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
ith, nth)) goto IQK_MulMat_Not_Available1;
return;
@@ -13123,8 +13167,8 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!iqk_mul_mat(ne01, ne11, ne00,
- src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
- vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type),
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
ith, nth)) goto IQK_MulMat_Not_Available2;
return;
@@ -13353,8 +13397,8 @@ static void ggml_compute_forward_mul_mat_id(
#if GGML_USE_IQK_MULMAT
if (ne13 == 1 && dst->type == GGML_TYPE_F32) {
if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
- src0->type, (const char *)src0_cur, nb01/ggml_type_size(src0->type),
- vec_dot_type, (const char *)wdata, row_size/ggml_type_size(vec_dot_type),
+ src0->type, (const char *)src0_cur, nb01, ///ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata, row_size, ///ggml_type_size(vec_dot_type),
(float *)dst->data, nb1, nb2,
matrix_rows + cur_a*ne12, ith, nth)) goto IQK_MulMat_Not_Available;
continue;
@@ -13870,7 +13914,7 @@ static void ggml_compute_forward_softcap(
default:
{
GGML_ASSERT(false);
- } break;
+ }
}
}
@@ -13986,7 +14030,7 @@ static void ggml_compute_forward_softcap_max(
default:
{
GGML_ASSERT(false);
- } break;
+ }
}
}
@@ -18652,11 +18696,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_SOFTCAP:
{
GGML_ASSERT(false); // TODO: not implemented
- } break;
+ }
case GGML_OP_SOFT_CAP_MAX:
{
GGML_ASSERT(false); // TODO: not implemented
- } break;
+ }
case GGML_OP_SET:
{
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
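With row_meta_size in the traits table, ggml_row_size() becomes the single source of truth for row strides, and it is no longer simply type_size*ne/blck_size; that is also why ggml_nbytes() and ggml_new_tensor_impl() above now go through nb[1]/ggml_row_size() instead of recomputing the product. Plugging in the values registered in this file:

    // ggml_row_size(type, ne) = row_meta_size + type_size*ne/blck_size
    // IQ2_TN, ne = 4096:  4 + 64*4096/256  = 4 + 1024 = 1028 bytes
    // IQ1_TN, ne = 4096:  2 + 52*4096/256  = 2 +  832 =  834 bytes
    // Q4_K,   ne = 4096:  0 + 144*4096/256 = 2304 bytes (unchanged)
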
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 7543d895..33b0a0d5 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -140,8 +140,9 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
return false;
}
- auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
- auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
+ size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA));
+ size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
+ //if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA));
auto nrc_x = (Nx + nth - 1)/nth;
auto first_x = ith*nrc_x;
@@ -165,8 +166,8 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;
}
- auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
- auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
+ size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA));
+ size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
int nrc_x = (Nx + nth - 1)/nth;
int first_x = ith*nrc_x;
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
@@ -378,11 +379,17 @@ struct ScaleIQ4XS {
const __m128i m32 = _mm_set1_epi16(-32);
};
-template <typename Block>
+template <typename Block, bool per_row_scale = false>
struct BaseDequantizer {
BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
inline void new_row(int ix) {
- x = (const Block *)((const char *)vx + bx*ix);
+ if constexpr (per_row_scale) {
+ const float * dptr = (const float *)((const char *)vx + bx*ix);
+ d = *dptr;
+ x = (const Block *)(dptr + 1);
+ } else {
+ x = (const Block *)((const char *)vx + bx*ix);
+ }
}
const void * vx;
@@ -700,14 +707,13 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
};
-struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> {
DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
new_block(i);
}
inline void new_block(int i) {
- d = GGML_FP16_TO_FP32(x[i].d);
bits.prepare(x[i].qs);
}
Q2Bits bits;
@@ -1158,7 +1164,7 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
deq1.new_block(i);
deq2.new_block(i);
- float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per per tensor, so we can use the same scale for both rows
+ //float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per tensor, so we can use the same scale for both rows
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi_scales_256 = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
@@ -1176,7 +1182,7 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[3], q8q);
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
 // The scale is supposed to be per tensor, so we can use the same scale
- auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
+ auto vd = _mm512_set1_ps(/*d* */q8.scale(iy, i));
accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
// Leaving this here just in case ternary models start using per row scales
@@ -1187,8 +1193,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
}
for (int iy = 0; iy < nrc_y; ++iy) {
- info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
- info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
+ info.store(ix+0, iy, deq1.d*_mm512_reduce_add_ps(accd[2*iy+0]));
+ info.store(ix+1, iy, deq2.d*_mm512_reduce_add_ps(accd[2*iy+1]));
}
}
@@ -4104,14 +4110,23 @@ struct Q2bits {
}
};
-template <typename block_q>
+template <typename block_q, bool has_row_scale = false>
struct BaseDequantizer {
BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
- inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); }
+ inline void new_row(int ix) {
+ if constexpr (has_row_scale) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ d = *dptr;
+ x = (const block_q *)(dptr + 1);
+ } else {
+ x = (const block_q *)((const char *)vx + ix*bx);
+ }
+ }
const void * vx;
const block_q * x;
const size_t bx;
const int nrc;
+ float d;
};
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
@@ -4133,7 +4148,6 @@ struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
Q4bits bits;
Scales8 s8;
- float d;
};
struct HighBit5 {
@@ -4202,7 +4216,6 @@ struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
uint8x16x2_t hbits;
- float d;
};
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
@@ -4256,7 +4269,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
const uint8x16_t mhb = vdupq_n_u8(0x30);
- float d;
};
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
@@ -4317,7 +4329,6 @@ struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
uint8x16_t mask;
HighBit3 h;
- float d;
};
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
@@ -4389,7 +4400,6 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
Q2bits bits;
- float d;
};
// ============================= i-quants
@@ -4453,7 +4463,6 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
const int8x16_t values;
const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
- float d;
};
struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
@@ -4503,7 +4512,6 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
const uint8x16_t hm = vdupq_n_u8(0x10);
uint8x16x2_t hbits;
- float d;
};
struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
@@ -4538,7 +4546,6 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
const int8x16x4_t values;
const uint8x16_t hm = vdupq_n_u8(0x30);
- float d;
};
struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
@@ -4570,7 +4577,6 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1));
const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
- float d;
};
struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
@@ -4630,7 +4636,6 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010});
const uint8x16_t sign_shuffle = load_sign_shuffle();
- float d;
};
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
@@ -4688,7 +4693,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};
- float d;
};
struct SimpleBits {
@@ -4747,7 +4751,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
uint32x4x4_t data;
SimpleBits bits;
- float d;
};
inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {
@@ -4793,7 +4796,6 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
SimpleBits bits;
- float d;
};
@@ -4857,7 +4859,6 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
SimpleBits bits;
SignHelper sh;
- float d;
};
@@ -4891,8 +4892,6 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
SimpleBits bits;
uint32x4x2_t gas;
- float d;
-
};
struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
@@ -4951,11 +4950,9 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
SignHelper sh;
uint32x4x2_t gas;
- float d;
-
};
-struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
+struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn, true> {
DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
@@ -4966,9 +4963,7 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
// d = GGML_FP16_TO_FP32(x[i].d);
//}
- inline void new_block(int i) {
- d = GGML_FP16_TO_FP32(x[i].d);
- }
+ inline void new_block(int) { }
template <typename Q8>
inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
@@ -5019,8 +5014,6 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
}
Q2bits bits;
-
- float d;
};
template <int nrc_y>
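Both SIMD paths above make the same trade: because d is now constant over a row, mul_mat_iq2tn_q8_K_AVX512 accumulates unscaled integer dot products and multiplies by deq1.d/deq2.d only once per output, instead of folding the scale into every superblock. In sketch form:

    // before: acc += (d * q8_scale) * dot(q2, q8)   per superblock
    // after:  acc +=      q8_scale  * dot(q2, q8)   per superblock
    //         out  = d * reduce(acc)                once per row
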
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 42441584..28bad18e 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -1972,15 +1972,15 @@ void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) {
auto quantize = [] (float xmax, float x) {
return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 1 : 2;
};
+ int n = k;
+ float max = fabsf(x[0]);
+ for (int j = 1; j < n; ++j) max = std::max(max, fabsf(x[j]));
+
+ *(float *)y = max;
+ y = (block_iq2_tn *)((float *)y + 1);
for (int ibl = 0; ibl < nb; ++ibl) {
auto xb = x + QK_K*ibl;
- float max = xb[0];
- for (int j = 0; j < QK_K; ++j) {
- float ax = fabsf(xb[j]);
- max = std::max(ax, max);
- }
- y[ibl].d = GGML_FP32_TO_FP16(max);
auto qs = y[ibl].qs;
for (int l = 0; l < QK_K/128; ++l) {
for (int j = 0; j < 32; ++j) {
@@ -1992,7 +1992,7 @@ void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) {
}
}
-void quantize_row_iq2_tn(const float * x, void * y, int64_t k) {
+void quantize_row_iq2_tn(const float * x, void * y, int64_t k) {
quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k);
}
@@ -2009,9 +2009,11 @@ size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_p
void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) {
GGML_ASSERT(k%QK_K == 0);
+ const float * dptr = (const float *)x;
+ float d = *dptr;
+ x = (const block_iq2_tn *)(dptr + 1);
int nb = k/QK_K;
for (int ibl = 0; ibl < nb; ++ibl) {
- float d = GGML_FP16_TO_FP32(x[ibl].d);
auto qs = x[ibl].qs;
for (int l = 0; l < QK_K/128; ++l) {
for (int j = 0; j < 32; ++j) {
@@ -2039,13 +2041,14 @@ void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t
const int nb = n / QK_K;
- const block_iq2_tn * x = (const block_iq2_tn *)vx;
+ const float * dptr = (const float *)vx;
+ const float d = *dptr;
+ const block_iq2_tn * x = (const block_iq2_tn *)(dptr + 1);
const block_q8_K * y = (const block_q8_K *)vy;
float sumf = 0;
for (int i = 0; i < nb; i++) {
- float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
auto qs = x[i].qs;
auto q8 = y[i].qs;
int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0;
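Since TriLM weights are ternary, a single scale per row loses nothing, so quantize_row_iq2_tn_ref now takes one absolute maximum over the whole row, stores it as a leading float, and emits scale-free 2-bit blocks after it. A round-trip sketch using the functions above (assumes k % QK_K == 0 and <vector> is included):

    std::vector<uint8_t> buf(sizeof(float) + (k/QK_K)*sizeof(block_iq2_tn));
    quantize_row_iq2_tn_ref(x, (block_iq2_tn *)buf.data(), k);  // [max | codes]
    std::vector<float> y(k);
    dequantize_row_iq2_tn((const block_iq2_tn *)buf.data(), y.data(), k);
    // y[j] = d*q - d with q in {0,1,2}, i.e. exactly {-max, 0, +max}
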
diff --git a/src/llama.cpp b/src/llama.cpp
index 0eea948a..df57c071 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4498,9 +4498,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet";
- case LLAMA_FTYPE_MOSTLY_IQ1_TN: return "IQ1_TN - 1.6875 bpw TriLM";
+ case LLAMA_FTYPE_MOSTLY_IQ1_TN: return "IQ1_TN - 1.625 bpw TriLM";
case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet";
- case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.06 bpw TriLM";
+ case LLAMA_FTYPE_MOSTLY_IQ2_TN: return "IQ2_TN - 2.00 bpw TriLM";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";