diff options
-rw-r--r-- | ggml/include/ggml.h | 6 | ||||
-rw-r--r-- | ggml/src/ggml-common.h | 12 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 18 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 4 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 3 | ||||
-rw-r--r-- | ggml/src/ggml.c | 21 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 347 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 6 |
9 files changed, 414 insertions, 10 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 144e87f5..a0bcc67f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -393,7 +393,8 @@ extern "C" { GGML_TYPE_IQ3_K = 38, GGML_TYPE_IQ4_K = 39, GGML_TYPE_IQ5_K = 40, - GGML_TYPE_IQ2_TN = 41, + GGML_TYPE_IQ6_K = 41, + GGML_TYPE_IQ2_TN = 42, GGML_TYPE_COUNT, }; @@ -444,7 +445,8 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors - GGML_FTYPE_MOSTLY_IQ2_TN = 34, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ6_K = 34, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_TN = 35, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 5847d903..2fbac06a 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -142,6 +142,9 @@ typedef sycl::half2 ggml_half2; #define QI5_XS (QK_K / (4*QR5_XS)) #define QR5_XS 2 +#define QI6_XS (QK_K / (4*QR6_XS)) +#define QR6_XS 2 + #define QI3_S (QK_K / (4*QR3_S)) #define QR3_S 4 @@ -493,6 +496,15 @@ typedef struct { } block_iq5_k; static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding"); +typedef struct { + ggml_half d; + uint16_t extra; + int8_t scales[QK_K/16]; + uint8_t qs[QK_K/2]; + uint8_t qh[QK_K/4]; +} block_iq6_k; +static_assert(sizeof(block_iq6_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/4 + QK_K/16, "wrong iq6_k block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c18e865a..07a53bcd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -705,6 +705,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ6_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR6_XS; + static constexpr int qi = QI6_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> { static constexpr int qk = QK_K; static constexpr int qr = QR3_S; diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 8def1547..ae5e6a3c 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -251,6 +251,16 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1( return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2); } +#define VDR_IQ6_K_Q8_1_MMVQ 4 +#define VDR_IQ6_K_Q8_1_MMQ 4 + +// TODO +__device__ __forceinline__ float vec_dot_iq6_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + return 0; +} + static const __device__ uint32_t iq2k_table[512] = { 0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311, 0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111, @@ -534,10 +544,16 @@ void mul_mat_vec_iq5_k_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +void mul_mat_vec_iq6_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + void mul_mat_vec_iq2_tn_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } - diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 3dc5f41c..7af8e570 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -16,6 +16,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); +void mul_mat_vec_iq6_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + void mul_mat_vec_iq2_tn_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 428d822f..9eb3fa4f 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -447,6 +447,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ5_K: mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ6_K: + mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ3_S: mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5c817030..b5fdb96d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1040,6 +1040,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ6_K] = { + .type_name = "iq6_k", + .blck_size = QK_K, + .type_size = sizeof(block_iq6_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq6_k, + .from_float = quantize_row_iq6_k, + .from_float_ref = (ggml_from_float_t)quantize_row_iq6_k_ref, + .vec_dot = vec_dot_iq6_k_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, }; // For internal test use @@ -3394,6 +3406,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break; + case GGML_FTYPE_MOSTLY_IQ6_K: wtype = GGML_TYPE_IQ6_K; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; @@ -9648,6 +9661,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -10033,6 +10047,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -10168,6 +10183,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13092,6 +13108,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13287,6 +13304,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -13556,6 +13574,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q4_0_4_4: @@ -14152,6 +14171,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: @@ -20892,6 +20912,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ6_K: result = quantize_iq6_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 1cba1532..f12367a4 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1516,6 +1516,346 @@ size_t quantize_iq5_k(const float * src, void * dst, int64_t nrows, int64_t n_pe } // +// ============================================== iq6_K +// +void dequantize_row_iq6_k(const block_iq6_k * x, float * y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const int8_t * sl = x[i].scales; + + uint16_t extra = x[i].extra; + + int shift = 0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + + float dl1 = d * sl[4*ib64 + 0]; + float dl2 = d * sl[4*ib64 + 1]; + float dl3 = d * sl[4*ib64 + 2]; + float dl4 = d * sl[4*ib64 + 3]; + int m1 = extra & 1 ? -127 : -125; + int m2 = extra & 2 ? -127 : -125; + int m3 = extra & 4 ? -127 : -125; + int m4 = extra & 8 ? -127 : -125; + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl1 * ((((qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 0x03) << 4)) << 2) + m1); + y[j+16] = dl2 * ((((qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 0x03) << 4)) << 2) + m2); + y[j+32] = dl3 * ((((qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 0x0c) << 2)) << 2) + m3); + y[j+48] = dl4 * ((((qs[j+16] >> 4) | (((qh[j+16] >> shift) & 0x0c) << 2)) << 2) + m4); + } + y += 64; + qs += 32; + extra >>= 4; + shift += 4; + if (shift == 8) { qh += 32; shift = 0; } + } + + } +} + +void vec_dot_iq6_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ6_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } + + // TODO + //const int nb = n / QK_K; + + //const block_iq5_k * x = (const block_iq5_k *)vx; + //const block_q8_K * y = (const block_q8_K *)vy; + + //float sumf = 0; + + //for (int i = 0; i < nb; i++) { + + // const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + // const uint8_t * qs = x[i].qs; + // const uint8_t * qh = x[i].qh; + // const uint8_t * sl = x[i].scales_l; + // const uint8_t * sh = x[i].scales_h; + // const int8_t * q8 = y[i].qs; + + // uint16_t extra = x[i].extra; + + // int shift = 0; + // int sumb = 0; + // for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + + // int dl1 = (((sl[2*ib64+0] & 0xf) | ((sh[ib64] << 4) & 0x30)) - 32); + // int dl2 = (((sl[2*ib64+0] >> 4) | ((sh[ib64] << 2) & 0x30)) - 32); + // int dl3 = (((sl[2*ib64+1] & 0xf) | ((sh[ib64] >> 0) & 0x30)) - 32); + // int dl4 = (((sl[2*ib64+1] >> 4) | ((sh[ib64] >> 2) & 0x30)) - 32); + // const int8_t * values1 = iq5nl_values + ((extra & 1) << 5); + // const int8_t * values2 = iq5nl_values + ((extra & 2) << 4); + // const int8_t * values3 = iq5nl_values + ((extra & 4) << 3); + // const int8_t * values4 = iq5nl_values + ((extra & 8) << 2); + // int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + // for (int j = 0; j < 16; ++j) { + // sumi1 += q8[j+ 0] * values1[(qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 1) << 4)]; + // sumi2 += q8[j+16] * values2[(qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 1) << 4)]; + // sumi3 += q8[j+32] * values3[(qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 2) << 3)]; + // sumi4 += q8[j+48] * values4[(qs[j+16] >> 4) | (((qh[j+16] >> shift) & 2) << 3)]; + // } + // sumb += dl1 * sumi1 + dl2 * sumi2 + dl3 * sumi3 + dl4 * sumi4; + // q8 += 64; + // qs += 32; + // extra >>= 4; + // shift += 2; + // } + // sumf += d * sumb; + + //} + + //*s = sumf; + +} + +namespace { + +void quantize_row_iq6_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { + const int ntry = 5; + const float step = 1.f; + + block_iq6_k * y = (block_iq6_k *)vy; + + float scales[QK_K/16]; + float weight[16]; + + uint8_t L[QK_K]; + + //int nerr = 0; + + for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq6_k)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + const float * xbl = x + ibl*QK_K; + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = 2*sumx2/QK_K; + + float max_abs_scale = 0; + uint16_t extra = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 1.f; //0.25f*sigma2 + xb[j]*xb[j]; + } + float amax = 0; + for (int j = 0; j < 16; ++j) { + float ax = fabsf(xb[j]); + amax = std::max(ax, amax); + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = amax/127; + float id = 0.25f/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + int lp = nearest_int(id*xb[j] + 31.25f); + lp = std::max(0, std::min(63, lp)); + float qp = 4*lp - 125; + sumqx_p += w*qp*xb[j]; + sumq2_p += w*qp*qp; + int lm = nearest_int(id*xb[j] + 31.75f); + lm = std::max(0, std::min(63, lm)); + float qm = 4*lm - 127; + sumqx_m += w*qm*xb[j]; + sumq2_m += w*qm*qm; + //printf("x = %g, lp = %d, qp = %g -> %g, lm = %d, qm = %g -> %g\n", xb[j], lp, qp, d*qp, lm, qm, d*qm); + } + d = sumqx_p/sumq2_p; + float best = d*sumqx_p; + bool is_shifted = false; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; is_shifted = true; + } + for (int itry = -ntry; itry <= ntry; ++itry) { + //0.25/amax*127 => 31.75/amax + id = (itry*step + 31.75f)/amax; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + int l = nearest_int(id*xb[j] + 31.25f); + l = std::max(0, std::min(63, l)); + float q = 4*l - 125; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = nearest_int(id*xb[j] + 31.75f); + l = std::max(0, std::min(63, l)); + q = 4*l - 127; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; + } + } + scales[ib] = d; + if (is_shifted) extra |= (1 << ib); + + max_abs_scale = std::max(max_abs_scale, amax); + + //float mse = 0; + //id = 0.25f/d; + //float xmin = is_shifted ? 31.75f : 31.25f; + //for (int j = 0; j < 16; ++j) { + // int l = nearest_int(id*xb[j] + xmin); + // l = std::max(0, std::min(63, l)); + // float diff = xb[j] - 4*d*(l - xmin); + // mse += diff*diff; + //} + //printf("Block %d: %g\n", ib, sqrtf(mse/16)); + + } + + if (!max_abs_scale) continue; + float d = max_abs_scale/255; + y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].extra = extra; + + float id = 1/d; + + std::memset(L, 0, QK_K); + float sumqx = 0, sumq2 = 0; + //float tot_mse = 0; + for (int ib = 0; ib < QK_K/16; ++ib) { + int ls = nearest_int(id*scales[ib]); + ls = MAX(0, MIN(255, ls)); + y[ibl].scales[ib] = ls; + float dl = d * ls; + if (dl) { + const float xmin = y[ibl].extra & (1 << ib) ? 31.75f : 31.25f; + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + ibl*QK_K + ib*16; + for (int j = 0; j < 16; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < 16; ++j) weight[j] = 1.f; //0.25f*sigma2 + xb[j]*xb[j]; + } + float idl = 0.25f/dl; + int ib32 = ib/2; + int offset = 16*(ib%2); + uint8_t * qs = y[ibl].qs + 32*(ib32/2) + offset; + uint8_t * qh = y[ibl].qh + 32*(ib32/4) + offset; + //float mse1 = 0, mse2 = 0; + for (int j = 0; j < 16; ++j) { + int l = nearest_int(idl*xb[j] + xmin); + l = std::max(0, std::min(63, l)); + L[16*ib + j] = l; + qs[j] |= ((l & 0xf) << 4*(ib32%2)); + qh[j] |= ((l >> 4) << 2*(ib32%4)); + float w = weight[j]; + float q = 4*(l - xmin)*ls; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + //float diff = xb[j] - 4*d*ls*(l - xmin); + //mse1 += diff*diff; + int ll = ((qs[j] >> 4*(ib32%2)) & 0xf) | (((qh[j] >> 2*(ib32%4)) << 4) & 0x30); + if (ll != l) { + printf("Oops: l = %d, ll = %d, qs = %u, qh = %u, ib = %d\n", l, ll, qs[j], qh[j], ib); + exit(1); + } + //diff = xb[j] - 4*d*ls*(ll - xmin); + //mse2 += diff*diff; + } + //printf("Block %d: %g, %g\n", ib, sqrtf(mse1/16), sqrtf(mse2/16)); + //tot_mse += mse1; + } + } + //printf("=============== rmse = %g, Old scale: %g New scale: %g\n", sqrtf(tot_mse/256), d, sumqx/sumq2); + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2); + + //d = GGML_FP16_TO_FP32(y[ibl].d); + //tot_mse = 0; + //for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + // const float * xb = xbl + 32*ib32; + // float dl1 = d * y[ibl].scales[2*ib32+0]; + // float dl2 = d * y[ibl].scales[2*ib32+1]; + // int min1 = y[ibl].extra & (1 << (2*ib32+0)) ? -127 : -125; + // int min2 = y[ibl].extra & (1 << (2*ib32+1)) ? -127 : -125; + // const uint8_t * qs = y[ibl].qs + 32*(ib32/2); + // const uint8_t * qh = y[ibl].qh + 32*(ib32/4); + // for (int j = 0; j < 16; ++j) { + // int l = ((qs[j] >> 4*(ib32%2)) & 0xf) | (((qh[j] >> 2*(ib32%4)) << 4) & 0x30); + // if (l != L[32*ib32 + j]) { + // ++nerr; + // printf("Oops: %d vs %u for ib32 = %d, j = %d. qs = %u (0x%02x), qh = %u (0x%02x)\n", l, L[32*ib32 + j], ib32, j, qs[j], qs[j], qh[j], qh[j]); + // if (nerr > 10) exit(1); + // } + // float diff = dl1*(4*l + min1) - xb[j]; + // tot_mse += diff*diff; + // //printf(" %d %d %g\n", l, 4*l + min1, diff); + // } + // for (int j = 16; j < 32; ++j) { + // int l = ((qs[j] >> 4*(ib32%2)) & 0xf) | (((qh[j] >> 2*(ib32%4)) << 4) & 0x30); + // if (l != L[32*ib32 + j]) { + // ++nerr; + // printf("Oops: %d vs %u for ib32 = %d, j = %d. qs = %u (0x%02x), qh = %u (0x%02x)\n", l, L[32*ib32 + j], ib32, j, qs[j], qs[j], qh[j], qh[j]); + // if (nerr > 10) exit(1); + // } + // float diff = dl2*(4*l + min2) - xb[j]; + // tot_mse += diff*diff; + // //printf(" %d %d %g\n", l, 4*l + min2, diff); + // } + //} + //printf(" after adjusting scale: d = %g, rmse = %g\n", d, sqrtf(tot_mse/256)); + + } + +} + +} + +void quantize_row_iq6_k_ref(const float * x, block_iq6_k * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq6_k(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq6_k(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq6_k * y = (block_iq6_k *)vy; + quantize_row_iq6_k_ref(x, y, k); +} + +size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq6_k_impl(src, (void *)qrow, n_per_row, imatrix); + src += n_per_row; + qrow += nblock*sizeof(block_iq6_k); + } + return nrows * nblock * sizeof(block_iq6_k); +} + +// // ========================== IQ2_TN // @@ -1582,13 +1922,6 @@ void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) { } void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - GGML_UNUSED(nrc); - GGML_UNUSED(bx); - GGML_UNUSED(by); - GGML_UNUSED(bs); - if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { return; } diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 80a9012b..3c5d27a4 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -37,6 +37,12 @@ size_t quantize_iq5_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, void dequantize_row_iq5_k(const block_iq5_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq5_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq6_k_ref(const float * GGML_RESTRICT x, block_iq6_k * GGML_RESTRICT y, int64_t k); +void quantize_row_iq6_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq6_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq6_k(const block_iq6_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq6_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void quantize_row_iq2_tn_ref(const float * GGML_RESTRICT x, block_iq2_tn * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); |