diff options
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 14 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 23 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 137 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 15 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cu | 12 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cuh | 253 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 12 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu | 5 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu | 5 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu | 5 | ||||
-rw-r--r-- | ggml/src/ggml-metal.metal | 77 | ||||
-rw-r--r-- | ggml/src/ggml.c | 25 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_gemm_ktquants.cpp | 1043 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 12 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 161 | ||||
-rw-r--r-- | src/llama.cpp | 1 |
16 files changed, 1668 insertions, 132 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8f3d2a26..a0cdab28 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -579,6 +579,20 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KT> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> { static constexpr int qk = QK_K; static constexpr int qr = QR4_XS; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 01b7250e..b40079a3 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -340,6 +340,12 @@ inline __device__ int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } +int __device__ __forceinline__ trellis_next_int(uint32_t& val) { + constexpr uint32_t ka = 0xCBAC1FED; + val = ka*val; + return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126); +} + float __device__ __forceinline__ trellis_next(uint32_t& val) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; @@ -367,9 +373,9 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f; for (int j = 0; j < 8; ++j) { - y[j] = dl * trellis_next(idx); + y[j] = dl * trellis_next_int(idx); } } @@ -388,10 +394,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f; + const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 1.01f; //1.015f; uint8_t mask = 1 << (ib/4); for (int j = 0; j < 8; ++j) { - y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f); + y[j] = dl * std::abs(trellis_next_int(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f); } } @@ -401,9 +407,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int64_t ii = blockIdx.x; int64_t row = (QK_K * ii) / n_per_row; const float * dptr = (const float *)((const char *)vx + row * row_size); - float scale = dptr[0] * 31.75f * 1.01f; - float row_av = dptr[1]; - const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + float scale = dptr[0] * 1.00f; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); const int64_t i = ii - (row*n_per_row)/QK_K; constexpr int kNumGroups = 64; @@ -423,8 +428,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int ls = ((shb[ib32] & 0xff) >> 1) - 64; const float dl = scale * ls; for (int j = 0; j < 4; ++j) { - y[j+0] = dl * trellis_next(idx1) + row_av; - y[j+4] = dl * trellis_next(idx2) + row_av; + y[j+0] = dl * trellis_next_int(idx1); + y[j+4] = dl * trellis_next_int(idx2); } } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index e5a224b4..c19215de 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -433,6 +433,119 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +__device__ __forceinline__ void vec_dot_iq4_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; // Why iqs/4 ? + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + //const int8_t * q8 = bq8_1[ib32].qs; + const int ls = (bq4->qs[ib32] & 0xff) >> 1; + const float dl = scale * (ls - 64); + const uint32_t idx0 = ((bq4->qs[ib32] & 1) << 15) + 4096; + auto ql = (const uint8_t *)(bq4->qs + 8); + auto qh = ql + 64; + ql += 8*ib32; + qh += 8*(ib32%4); + const int shift1 = 8 - 4*(ib32/4); + int sumi = 0; + for (int j = 0; j < 8; ++j) { + const uint32_t sh = bq4->qs[ib32] >> (8 + 3*j); + uint32_t val = ql[j] + ((qh[j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + //int s = val & km; + //sumi += q8[4*j+k] * ggml_cuda_dp4a(s, 0x01010101, -126); + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[j], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + +__device__ __forceinline__ void vec_dot_iq2_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq2_kt * bq2 = (const block_iq2_kt *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const int ls = iq4k_values[(bq2->scales[ib32%4] >> 4*(ib32/4)) & 0xf]; + const float dl = scale * ls * 1.05f; + auto ql = (const uint16_t *)bq2->ql; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + uint32_t val = ql[4*ib32+j] + 4096; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi); + v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + +__device__ __forceinline__ void vec_dot_iq3_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq3_kt * bq3 = (const block_iq3_kt *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const int ls = (bq3->scales[ib32%4] >> 4*(ib32/4)) & 0xf; + const float dl = scale * ls * 1.015f; + auto ql = (const uint16_t *)bq3->ql; + uint32_t mask = 0x01010101 << ib32; + const uint32_t * qh = (const uint32_t *)bq3->qh; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + uint32_t val = ql[4*ib32+j] + 4096; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)); + v4 |= q << 8*k; + } + uint32_t signs = __vcmpne4(qh[2*j+0] & mask, 0); + v4 = __vsub4(v4 ^ signs, signs); + sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi); + v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)); + v4 |= q << 8*k; + } + signs = __vcmpne4(qh[2*j+1] & mask, 0); + v4 = __vsub4(v4 ^ signs, signs); + sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ4_KSS_Q8_1_MMVQ 4 #define VDR_IQ4_KSS_Q8_1_MMQ 4 @@ -1217,6 +1330,30 @@ void mul_mat_vec_iq4_ks_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq4_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + +void mul_mat_vec_iq2_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + +void mul_mat_vec_iq3_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq3_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 17bf5ad2..e7c6e1d2 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -100,3 +100,18 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq2_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq3_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq4_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index a13be11b..9103e3f1 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -100,6 +100,15 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KS_R4: mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream); break; + case GGML_TYPE_IQ4_KT: + mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream); + break; + case GGML_TYPE_IQ2_KT: + mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream); + break; + case GGML_TYPE_IQ3_KT: + mul_mat_q_case<GGML_TYPE_IQ3_KT>(ctx, args, stream); + break; case GGML_TYPE_IQ5_KS: mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream); break; @@ -172,6 +181,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 608de8f0..c6e6a365 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -93,6 +93,9 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -202,6 +205,9 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0; default : return tile_x_sizes{0, 0, 0}; } } @@ -250,6 +256,9 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; default : return 0; } } @@ -2790,6 +2799,226 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin } +template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto shb = bxi->qs; + const auto ql = (const uint8_t *)(shb + 8); + const auto qh = ql + 64; + const uint32_t sh = shb[ib32] >> (8 + 6*j); + uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); + uint32_t val1 = offset + ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12); + uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9); + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val1 *= ka; + val2 *= ka; + v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k; + v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0; + const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#endif // INT8_MMA_AVAILABLE + } +} + +template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_kt * bxi = (const block_iq2_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto ql = (const uint16_t *)bxi->ql; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.05f; + const block_iq2_kt * bxi = (const block_iq2_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = iq4k_values[(bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf]; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_kt * bxi = (const block_iq3_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto ql = (const uint16_t *)bxi->ql; + const auto qh = (const uint32_t *)bxi->qh; + uint32_t mask = 0x01010101 << ib32; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + auto signs = __vcmpne4(qh[2*j+0] & mask, 0); + v.x = __vsub4(v.x ^ signs, signs); + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + signs = __vcmpne4(qh[2*j+1] & mask, 0); + v.y = __vsub4(v.y ^ signs, signs); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.01f; + const block_iq3_kt * bxi = (const block_iq3_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = (bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3383,6 +3612,27 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS_R4> { }; template <int mmq_x, int mmq_y, int nwarps, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KT> { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt<mmq_y, nwarps, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; +}; + +template <int mmq_x, int mmq_y, int nwarps, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_KT> { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kt<mmq_y, nwarps, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; +}; + +template <int mmq_x, int mmq_y, int nwarps, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_KT> { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_kt<mmq_y, nwarps, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; +}; + +template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> { static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; @@ -3843,6 +4093,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 73caabab..6412be30 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -527,6 +527,15 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ2_KT: + mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ3_KT: + mul_mat_vec_iq3_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_KT: + mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; @@ -687,6 +696,9 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: return true; default: return false; diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu new file mode 100644 index 00000000..2d48f077 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu new file mode 100644 index 00000000..978bc6ca --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu new file mode 100644 index 00000000..4590f919 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index a05a890e..c3c4f0bb 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6596,6 +6596,37 @@ void kernel_mul_mv_iq2_k_f32_impl( } } +struct Trellis3 { + constexpr constant static uint32_t kmask = 0x3f3f3f3f; + constexpr constant static uint32_t ka = 89226354; + constexpr constant static uint32_t kb = 64248484; + constexpr constant static uint32_t ka1 = ka*ka; + constexpr constant static uint32_t kb1 = kb*ka+kb; + constexpr constant static uint32_t ka2 = ka1*ka; + constexpr constant static uint32_t kb2 = kb1*ka+kb; + constexpr constant static uint32_t ka3 = ka2*ka; + constexpr constant static uint32_t kb3 = kb2*ka+kb; + static inline char4 gen4(uint32_t val) { + thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask}; + thread const int8_t * a8 = (thread const int8_t *)aux; + char4 result; + for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3]; + return result; + } + template <typename T4> + static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) { + thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3}; + uint32_t aux32[2]; + thread const int8_t * a8 = (thread const int8_t *)aux32; + for (int i = 0; i < 4; ++i) { + aux32[0] = aux[i] & kmask; + aux32[1] = (ka3*aux[i] + kb3) & kmask; + v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3]; + v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7]; + } + } +}; + struct Trellis { constexpr constant static uint32_t kmask1 = 0x8fff8fff; constexpr constant static uint32_t kmask2 = 0x3b603b60; @@ -6691,7 +6722,7 @@ void kernel_mul_mv_iq2_kt_f32_impl( float drow[N_DST]; for (int row = 0; row < N_DST; ++row) { device const float * dptr = (device const float *)(cx + row*row_size); - drow[row] = dptr[0] * 31.75f * 1.05f; + drow[row] = dptr[0] * 1.05f; } device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float)); @@ -6706,10 +6737,10 @@ void kernel_mul_mv_iq2_kt_f32_impl( const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf]; - Trellis::gen8(q2[2*it+0]+4096, v1, v2); + Trellis3::gen8(q2[2*it+0]+4096, v1, v2); auto sum = v1*y4[0] + v2*y4[1]; - Trellis::gen8(q2[2*it+1]+4096, v1, v2); + Trellis3::gen8(q2[2*it+1]+4096, v1, v2); sum += v1*y4[2] + v2*y4[3]; sum *= ls; @@ -8542,19 +8573,18 @@ template <typename type4x4> void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 int ib32 = il/2; - half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 31.75h * 1.05h; + half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 1.05h; device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2); - half4 v1, v2; + char4 v1, v2; for (int i = 0; i < 2; ++i) { - Trellis::gen8(q2[i]+4096, v1, v2); - v1 *= scale; v2 *= scale; + Trellis3::gen8(q2[i]+4096, v1, v2); if constexpr (is_same_v<type4x4, half4x4>) { - reg[2*i+0] = v1; - reg[2*i+1] = v2; + reg[2*i+0] = {scale*(half)v1[0], scale*(half)v1[1], scale*(half)v1[2], scale*(half)v1[3]}; + reg[2*i+1] = {scale*(half)v2[0], scale*(half)v2[1], scale*(half)v2[2], scale*(half)v2[3]}; } else { - reg[2*i+0] = {(float)v1[0], (float)v1[1], (float)v1[2], (float)v1[3]}; - reg[2*i+1] = {(float)v2[0], (float)v2[1], (float)v2[2], (float)v2[3]}; + reg[2*i+0] = {scale*(float)v1[0], scale*(float)v1[1], scale*(float)v1[2], scale*(float)v1[3]}; + reg[2*i+1] = {scale*(float)v2[0], scale*(float)v2[1], scale*(float)v2[2], scale*(float)v2[3]}; } } } @@ -8586,20 +8616,20 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread device const uint32_t * shb = x->qs; device const uint8_t * ql = (device const uint8_t *)(shb + 8); device const uint8_t * qh = ql + 64; - float scale = d * (((shb[ib32] & 0xff) >> 1) - 64); + const int ls = (shb[ib32] & 0xff) >> 1; + const float scale = d * (ls - 64); const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); - const int jj = ib32*8 + 4*(il%2); - ql += jj; - qh += jj%32; + ql += 8*ib32; + qh += 8*(ib32%4); uint32_t sh = (shb[ib32] >> (8 + 12*(il%2))) << 12; - const int shift = 8 - 4*(jj/32); + const int shift = 8 - 4*(ib32/4); for (int i = 0; i < 4; ++i) { uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset; - auto v = (float4)Trellis::gen4(idx); - reg[i] = v * scale; + auto c4 = Trellis3::gen4(idx); + reg[i] = {scale*c4[0], scale*c4[1], scale*c4[2], scale*c4[3]}; } } @@ -8931,18 +8961,17 @@ struct DequantizerKT4 { using type4x4 = T4x4; DequantizerKT4(device const char * cx, short il = 0) : il(il) { device const float * dptr = (device const float *)cx; - d[0] = dptr[0] * 31.75f * 1.01f; - d[1] = dptr[1]; - x = (device const Block *)(dptr + 2); + d = dptr[0] * 1.01f; + x = (device const Block *)(dptr + 1); } inline void convert(thread T4x4& t) const { float4x4 tmp; - dequantize_iq4_kt(x, il, d[0], tmp); + dequantize_iq4_kt(x, il, d, tmp); for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void convert(int64_t ind, thread T4x4& t) { float4x4 tmp; - dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp); + dequantize_iq4_kt(x + ind/nl, ind%nl, d, tmp); for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void next() { @@ -8951,7 +8980,7 @@ struct DequantizerKT4 { } device const Block * x; short il; - float d[2]; + float d; }; template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(half d, device const Block *, short, thread T4x4&), bool may_not_be_aligned = false> diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 69b1b46d..cc056f89 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1596,10 +1596,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq2_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref, .vec_dot = vec_dot_iq2_kt_q8_k, -#ifdef __ARM_NEON - .vec_dot_type = GGML_TYPE_F16, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else - .vec_dot_type = GGML_TYPE_F32, + .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif .nrows = 1, .row_meta_size = 4, @@ -1613,11 +1613,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq3_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref, .vec_dot = vec_dot_iq3_kt_q8_k, -#ifdef __ARM_NEON - .vec_dot_type = GGML_TYPE_F16, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else - .vec_dot_type = GGML_TYPE_F32, + .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif +//#ifdef __ARM_NEON +// .vec_dot_type = GGML_TYPE_F16, +//#else +// .vec_dot_type = GGML_TYPE_F32, +//#endif .nrows = 1, .row_meta_size = 4, }, @@ -1630,13 +1635,13 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref, .vec_dot = vec_dot_iq4_kt_q8_k, -#ifdef __ARM_NEON - .vec_dot_type = GGML_TYPE_F16, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else - .vec_dot_type = GGML_TYPE_F32, + .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif .nrows = 1, - .row_meta_size = 8, + .row_meta_size = 4, }, [GGML_TYPE_IQ3_K] = { .type_name = "iq3_k", diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index bc7bcf8b..2ddfbe86 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -97,6 +97,134 @@ struct Trellis2 { } }; + +template <bool is_8 = false, bool is_abs = false> +struct Trellis3 { + constexpr static uint32_t ka = 0xCBAC1FED; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t ka3 = ka2*ka; + constexpr static uint32_t ka4 = ka3*ka; + constexpr static uint32_t ka5 = ka4*ka; + constexpr static uint32_t ka6 = ka5*ka; + constexpr static uint32_t ka7 = ka6*ka; + const __m256i mka = is_8 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); + const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + + inline __m256i next8(uint32_t val1, uint32_t val2) const { + __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); + return _mm256_mullo_epi32(mval, mka); + } + inline __m256i next8(uint32_t val) const { + __m256i mval = _mm256_set1_epi32(val); + return _mm256_mullo_epi32(mval, mka); + } + inline __m256 gen8(uint32_t val1, uint32_t val2) const { + auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); +#else + auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); + auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif + if constexpr (is_abs) { + return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8)); + } else { + return _mm256_cvtepi32_ps(i8); + } + } + inline __m256 gen8(uint32_t val) const { + auto v8 = _mm256_and_si256(next8(val), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); +#else + auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); + auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif + if constexpr (is_abs) { + return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8)); + } else { + return _mm256_cvtepi32_ps(i8); + } + } + inline __m256i next32(const uint32_t * val) const { + const __m256i offset = _mm256_set1_epi32(-126); + __m256i aux[4]; + for (int i = 0; i < 4; ++i) { + auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8); +#else + auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101)); + aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif + } + aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + if constexpr (is_abs) { + auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle); + return _mm256_sign_epi8(result, result); + } else { + return _mm256_permutevar8x32_epi32(aux[0], shuffle); + } + } + inline __m256i next32(const uint16_t * val, uint32_t v0) const { + const __m256i offset = _mm256_set1_epi32(-126); + __m256i aux[4]; + for (int i = 0; i < 4; ++i) { + auto i8 = _mm256_and_si256(next8(v0 + val[i]), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8); +#else + auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101)); + aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif + } + aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + if constexpr (is_abs) { + auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle); + return _mm256_sign_epi8(result, result); + } else { + return _mm256_permutevar8x32_epi32(aux[0], shuffle); + } + } + inline void next64(const uint32_t * val, __m256i * result) const { + const __m256i offset = _mm256_set1_epi32(-126); + auto vka3 = _mm256_set1_epi32(ka3); + __m256i aux[8]; + for (int i = 0; i < 4; ++i) { + auto i8_1 = next8(val[2*i+0], val[2*i+1]); + auto i8_2 = _mm256_mullo_epi32(i8_1, vka3); + i8_1 = _mm256_and_si256(i8_1, _mm256_set1_epi32(0x3f3f3f3f)); + i8_2 = _mm256_and_si256(i8_2, _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + aux[i+0] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8_1); + aux[i+4] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8_2); +#else + auto dot1 = _mm256_maddubs_epi16(i8_1, _mm256_set1_epi32(0x01010101)); + auto dot2 = _mm256_maddubs_epi16(i8_2, _mm256_set1_epi32(0x01010101)); + aux[i+0] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot1, _mm256_set1_epi16(1))); + aux[i+4] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot2, _mm256_set1_epi16(1))); +#endif + } + for (int k = 0; k < 2; ++k) { + aux[4*k+0] = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + aux[4*k+2] = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + aux[4*k+0] = _mm256_packs_epi16(aux[4*k+0], aux[4*k+2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + result[k] = _mm256_permutevar8x32_epi32(aux[4*k+0], shuffle); + if constexpr (is_abs) { + result[k] = _mm256_sign_epi8(result[k], result[k]); + } + } + } +}; + void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -136,6 +264,60 @@ void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t } } +void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq2_kt * x8[8]; + float dkt[8]; + float ls[8]; + float ls_all[64]; + uint32_t idx[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq2_kt *)(dptr + 1); + } + auto vd = _mm256_mul_ps(_mm256_set1_ps(1.05f), _mm256_loadu_ps(dkt)); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto s8 = _mm_set1_epi32(*(const uint32_t *)x8[k][i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32)); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib]; + auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls)); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 8; ++k) { + const uint16_t * ql = (const uint16_t *)x8[k][i].ql; + idx[k] = ql[4*ib+j] + 4096; + } + __m256i packed[2]; + trellis.next64(idx, packed); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]); + } + } + y += 8; // = QK_K/32; + } + } +} + template <int nrc_y> void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -198,6 +380,243 @@ void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +template <int nrc_y> +void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis3<true> trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + __m256i xv[4], dot[4]; + __m256 scales[2]; + + auto sum_4 = [&dot] () { + // dot[k] has 8 values from block k + // 0 1 0 1 0 1 0 1 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1])); + // 2 3 2 3 2 3 2 3 + dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3])); + // 0 1 2 3 0 1 2 3 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2])); + return _mm256_cvtepi32_ps(dot[0]); + }; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = _mm256_loadu_si256((const __m256i *)y + k); +#ifdef HAVE_FANCY_SIMD + //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); +#else + auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); + dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1)); +#endif + } + }; + + //auto m126 = _mm256_set1_ps(-126.f); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0] * 1.05f); + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32)); + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + for (int i128 = 0; i128 < 2; ++i128) { + //for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k); + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& yb = y[iy][2*i+i128]; + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); + dy = _mm256_mul_ps(scales[i128], dy); + auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + compute_dot(yb.qs); + accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); + //accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + +void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + + Trellis3<false, true> trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq3_kt * x8[8]; + float dkt[8]; + float ls[8]; + float ls_all[64]; + uint32_t idx[8]; + uint32_t sign_bits[16]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq3_kt *)(dptr + 1); + } + auto vd = _mm256_mul_ps(_mm256_set1_ps(1.01f), _mm256_loadu_ps(dkt)); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto s8 = _mm_set1_epi32(*(const uint32_t *)x8[k][i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + auto s32 = _mm256_cvtepi8_epi32(s8); + _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32)); + } + auto mask = _mm256_set1_epi8(1); + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib]; + auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls)); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 8; ++k) { + const uint16_t * ql = (const uint16_t *)x8[k][i].ql; + idx[k] = ql[4*ib+j] + 4096; + auto qh = (const uint32_t *)x8[k][i].qh; + sign_bits[k+0] = qh[2*j+0]; + sign_bits[k+8] = qh[2*j+1]; + } + __m256i packed[2]; + trellis.next64(idx, packed); + auto signs1 = _mm256_loadu_si256((const __m256i *)sign_bits+0); + auto signs2 = _mm256_loadu_si256((const __m256i *)sign_bits+1); + signs1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs1, mask), mask), _mm256_set1_epi8(1)); + signs2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs2, mask), mask), _mm256_set1_epi8(1)); + packed[0] = _mm256_sign_epi8(packed[0], signs1); + packed[1] = _mm256_sign_epi8(packed[1], signs2); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]); + } + mask = _mm256_slli_epi16(mask, 1); + } + y += 8; // = QK_K/32; + } + } +} + +template <int nrc_y> +void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis3<true, true> trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + __m256i xv[4], sv[4], dot[4]; + __m256 scales[2]; + + auto sum_4 = [&dot] () { + // dot[k] has 8 values from block k + // 0 1 0 1 0 1 0 1 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1])); + // 2 3 2 3 2 3 2 3 + dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3])); + // 0 1 2 3 0 1 2 3 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2])); + return _mm256_cvtepi32_ps(dot[0]); + }; + + auto compute_dot = [&dot, &xv, &sv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = _mm256_loadu_si256((const __m256i *)y + k); +#ifdef HAVE_FANCY_SIMD + //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], _mm256_sign_epi8(yv, sv[k])); +#else + auto p = _mm256_maddubs_epi16(xv[k], _mm256_sign_epi8(yv, sv[k])); + dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1)); +#endif + } + }; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0] * 1.01f); + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + auto ql = (const uint16_t *)x[i].ql; + auto sign_bits = _mm256_loadu_si256((const __m256i *)x[i].qh); + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + auto s32 = _mm256_cvtepi8_epi32(s8); + auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32)); + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + auto mask = _mm256_set1_epi8(1); + for (int i128 = 0; i128 < 2; ++i128) { + for (int k = 0; k < 4; ++k) { + xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096); + sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1)); + mask = _mm256_slli_epi16(mask, 1); + } + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& yb = y[iy][2*i+i128]; + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); + dy = _mm256_mul_ps(scales[i128], dy); + auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + compute_dot(yb.qs); + accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + inline __m256 abs_ps(__m256 vals) { // Clear sign-bit of all the 32-bit floats in vals __m256 sign_bit = _mm256_set1_ps(-0.0f); @@ -315,19 +734,67 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[16]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 1); + } + auto vd = _mm256_loadu_ps(dkt); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + _mm256_storeu_si256((__m256i *)y[ib].qs+j, trellis.next32(idx)); + } + } + y += 8; // = QK_K/32; + } + + } +} + void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; - Trellis2 trellis; + Trellis3 trellis; union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); - auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto d = _mm256_set1_ps(dptr[0]); auto dav = _mm256_set1_ps(dptr[1]); const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); @@ -349,8 +816,8 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - auto x_val1 = _mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav); - auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav); + auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav); + auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav); _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); @@ -365,12 +832,112 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t } template <int nrc_y> +void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + union { __m256i vec; uint32_t val[8]; } o_helper; + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + uint32_t values[64]; + __m256i xv[4], dot[4]; + __m256 scales[2]; + + auto sum_4 = [&dot] () { + // dot[k] has 8 values from block k + // 0 1 0 1 0 1 0 1 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1])); + // 2 3 2 3 2 3 2 3 + dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3])); + // 0 1 2 3 0 1 2 3 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2])); + return _mm256_cvtepi32_ps(dot[0]); + }; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = _mm256_loadu_si256((const __m256i *)y + k); +#ifdef HAVE_FANCY_SIMD + //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); +#else + auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); + dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1)); +#endif + } + }; + + //auto m126 = _mm256_set1_ps(-126.f); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0]); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); + iscales = _mm256_sub_epi32(iscales, _mm256_set1_epi32(64)); + auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(iscales)); + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); + for (int ib = 0; ib < 4; ++ib) { + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + } + } + for (int i128 = 0; i128 < 2; ++i128) { + //for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k); + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& yb = y[iy][2*i+i128]; + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); + dy = _mm256_mul_ps(scales[i128], dy); + auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + compute_dot(yb.qs); + accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); + //accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + +template <int nrc_y> void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; - Trellis2 trellis; + Trellis3 trellis; union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; @@ -389,7 +956,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); - auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto d = _mm256_set1_ps(dptr[0]); auto dav = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); @@ -413,8 +980,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3))); - auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4))); + auto x_val1 = _mm256_mul_ps(scale1, trellis.gen8(val1, val3)); + auto x_val2 = _mm256_mul_ps(scale2, trellis.gen8(val2, val4)); if constexpr (nrc_y == 1) { auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+ 0); auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128); @@ -446,11 +1013,37 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { - if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) { + if (ne00%QK_K != 0) return false; + + func16 = nullptr; + + if (typeA == GGML_TYPE_IQ4_KT) { + if (typeB == GGML_TYPE_Q8_2_X4) { + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels); + return true; + } return false; } - func16 = nullptr; + if (typeA == GGML_TYPE_IQ2_KT) { + if (typeB == GGML_TYPE_Q8_2_X4) { + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_2_x4_T, kernels); + return true; + } + return false; + } + + if (typeA == GGML_TYPE_IQ3_KT) { + if (typeB == GGML_TYPE_Q8_2_X4) { + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_2_x4_T, kernels); + return true; + } + return false; + } + + if (ggml_type(typeB) != GGML_TYPE_F32) { + return false; + } switch (typeA) { case GGML_TYPE_IQ2_KT: @@ -470,11 +1063,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat } -bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) { +bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) { switch (type) { - case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; - case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; - case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break; + case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break; + case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break; default: return false; } return true; @@ -992,21 +1585,423 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +struct Trellis3 { + constexpr static uint32_t ka = ;0xCBAC1FED; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t ka3 = ka2*ka; + const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3}; + const uint8x16_t shuffle = load_shuffle(); + + inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const { + uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)}; + result.val[0] = vmulq_u32(mka, result.val[0]); + result.val[1] = vmulq_u32(mka, result.val[1]); + return result; + } + inline int8x16x2_t next32(const uint32_t * val) const { + int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; + for (int i = 0; i < 2; ++i) { + auto i8 = next8(val[4*i+0], val[4*i+1]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + i8 = next8(val[4*i+2], val[4*i+3]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2)); + } + return result; + } + inline int8x16x2_t next32(const uint16_t * val, uint32_t v0) const { + auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3); + int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; + int8x16x2_t i8; + for (int i = 0; i < 2; ++i) { + i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+0]+v0)); + i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+1]+v0)); + i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2)); + } + return result; + } + inline int8x16x4_t next64(const uint32_t * val) const { + auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3); + int8x16x4_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126)}; + for (int i = 0; i < 2; ++i) { + auto i8_1 = next8(val[4*i+0], val[4*i+1]); + int8x16x2_t i8_2{vmlaq_u32(vkb3, vka3, i8_1.val[0]), vmlaq_u32(vkb3, vka3, i8_1.val[1])}; + i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1_1 = vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1])); + auto s1_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1])); + i8_1 = next8(val[4*i+2], val[4*i+3]); + i8_2.val[0] = vmlaq_u32(vkb3, vka3, i8_1.val[0]); + i8_2.val[1] = vmlaq_u32(vkb3, vka3, i8_1.val[1]); + i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s2_1 = vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1])); + auto s2_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1])); + result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1)); + result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2)); + } + return result; + } + static uint8x16_t load_shuffle() { + static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + return vld1q_u8(k_shuffle); + } +}; + +void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 1); + } + auto vd = vld1q_f32_x2(dkt); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales1 = vmulq_f32(vd.val[0], vcvtq_f32_s32(vld1q_s32(ls+0))); + auto scales2 = vmulq_f32(vd.val[1], vcvtq_f32_s32(vld1q_s32(ls+4))); + vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1)); + vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2)); + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + vst1q_s8_x2(y[ib].qs+32*j, trellis.next32(idx)); + } + } + y += 8; // = QK_K/32; + } + } +} + +template <int nrc_y> +void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + union { uint32x4x2_t vec; uint32_t val[8]; } o_helper; + + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; + + float32x4_t accd[k_acc]; + + const block_q8_0_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); + } + + uint32_t values[16]; + int8x16x2_t xv[8]; + int32x4x4_t dot; + + auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) { + for (int k = 0; k < 4; ++k) { + auto yv = vld1q_s8_x2(y + 32*k); + dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); + } + dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]); + dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]); + return vpaddq_s32(dot.val[0], dot.val[2]); + }; + + int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}}; + + float32x4x2_t scales; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = vdupq_n_f32(dptr[0]); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); + + for (int i = 0; i < nb; ++i) { + auto vshb = vld1q_u32_x2(x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales1 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff)), 1)); + auto iscales2 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff)), 1)); + iscales1 = vaddq_s32(iscales1, vdupq_n_s32(-64)); + iscales2 = vaddq_s32(iscales2, vdupq_n_s32(-64)); + scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(iscales1)); + scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(iscales2)); + o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); + o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); + for (int ib = 0; ib < 4; ++ib) { + auto vql1 = vmovl_u8(vld1_u8(ql+8*ib)); + auto vql2 = vmovl_u8(vld1_u8(ql+8*ib+32)); + auto vqh = vmovl_u8(vld1_u8(qh+8*ib)); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 8))); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 4))); + auto sh1_u32 = vdupq_n_u32(shb[ib+0]); + auto sh2_u32 = vdupq_n_u32(shb[ib+4]); + auto sh1 = vcombine_u16(vmovn_u32(vshlq_u32(sh1_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh1_u32, shifts.val[1]))); + auto sh2 = vcombine_u16(vmovn_u32(vshlq_u32(sh2_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh2_u32, shifts.val[1]))); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0x7000), sh1)); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0x7000), sh2)); + auto oh1 = vdupq_n_u32(o_helper.val[ib+0]); + auto oh2 = vdupq_n_u32(o_helper.val[ib+4]); + vst1q_u32(values +0, vaddq_u32(vmovl_u16(vget_low_u16 (vql1)), oh1)); + vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1)); + vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2)); + vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2)); + xv[ib+0] = trellis.next32(values+0); + xv[ib+4] = trellis.next32(values+8); + } + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& ybl = y[iy][2*i+0]; + const block_q8_0_x4& ybh = y[iy][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + auto sumil = compute_dot(ybl.qs, xv+0); + auto sumih = compute_dot(ybh.qs, xv+4); + if constexpr (nrc_y == 1) { + accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil)); + accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih)); + } else { + accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil)); + accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); + } + } + } + + if constexpr (nrc_y == 1) { + info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1]))); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } + } + } +} + +void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto values = vld1q_s8(iq4k_values); + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq2_kt * x8[8]; + float dkt[8]; + float ls[8], ls_all[64]; + uint32_t idx[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0] * 1.05f; + x8[k] = (const block_iq2_kt *)(dptr + 1); + } + auto vd = vld1q_f32_x2(dkt); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto u32 = *(const uint32_t *)x8[k][i].scales; + auto s8_u32 = uint32x2_t{u32, u32 >> 4}; + s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f)); + auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32)); + auto s16 = vmovl_s8(s8); + vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16)))); + vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16)))); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib]; + auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0)); + auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4)); + vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1)); + vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 8; ++k) { + const uint16_t * ql = (const uint16_t *)x8[k][i].ql; + idx[k] = ql[4*ib+j] + 4096; + } + vst1q_s8_x4(y[ib].qs+64*j, trellis.next64(idx)); + } + } + y += 8; // = QK_K/32; + } + } +} + +template <int nrc_y> +void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto values = vld1q_s8(iq4k_values); + + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; + + float32x4_t accd[k_acc]; + + const block_q8_0_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); + } + + int8x16x2_t xv[8]; + int32x4x4_t dot; + + auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) { + for (int k = 0; k < 4; ++k) { + auto yv = vld1q_s8_x2(y + 32*k); + dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); + } + dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]); + dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]); + return vpaddq_s32(dot.val[0], dot.val[2]); + }; + + float32x4x2_t scales; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = vdupq_n_f32(dptr[0]*1.05f); + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); + + for (int i = 0; i < nb; ++i) { + auto u32 = *(const uint32_t *)x[i].scales; + auto s8_u32 = uint32x2_t{u32, u32 >> 4}; + s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f)); + auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32)); + auto s16 = vmovl_s8(s8); + scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16)))); + scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16)))); + const uint16_t * ql = (const uint16_t *)x[i].ql; + if constexpr (nrc_y == 1) { + const block_q8_0_x4& ybl = y[0][2*i+0]; + const block_q8_0_x4& ybh = y[0][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + int32x4x4_t suml = {}; + int32x4x4_t sumh = {}; + for (int ib = 0; ib < 4; ++ib) { + auto xl = trellis.next32(ql + 4*ib + 0, 4096); + auto xh = trellis.next32(ql + 4*ib + 16, 4096); + auto yl = vld1q_s8_x2(ybl.qs + 32*ib); + auto yh = vld1q_s8_x2(ybh.qs + 32*ib); + suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]); + sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]); + } + auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]); + auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]); + auto sl = vpaddq_s32(sl1, sl2); + auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]); + auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]); + auto sh = vpaddq_s32(sh1, sh2); + accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl)); + accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh)); + } else { + for (int k = 0; k < 8; ++k) xv[k] = trellis.next32(ql + 4*k, 4096); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& ybl = y[iy][2*i+0]; + const block_q8_0_x4& ybh = y[iy][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + auto sumil = compute_dot(ybl.qs, xv+0); + auto sumih = compute_dot(ybh.qs, xv+4); + if constexpr (nrc_y == 1) { + accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil)); + accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih)); + } else { + accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil)); + accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); + } + } + } + } + + if constexpr (nrc_y == 1) { + info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1]))); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } + } + } +} + } bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { - //if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) { - // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels); - // func16 = nullptr; - // return true; - //} - if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) { + if (ne00%QK_K != 0) return false; + + if (ggml_type(typeA) == GGML_TYPE_IQ4_KT) { + if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) { + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_0_x4_T, kernels); + func16 = nullptr; + return true; + } return false; } - func16 = nullptr; + if (ggml_type(typeA) == GGML_TYPE_IQ2_KT) { + if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) { + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_0_x4_T, kernels); + func16 = nullptr; + return true; + } + return false; + } + + if (ggml_type(typeB) != GGML_TYPE_F16) { + return false; + } switch (typeA) { case GGML_TYPE_IQ2_KT: @@ -1022,14 +2017,16 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat return false; } + func16 = nullptr; + return true; } bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) { switch (type) { - case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break; case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break; - case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break; default: return false; } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 6925e6a6..f718e43e 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -236,9 +236,6 @@ struct MulMat { static inline ggml_type is_dequant_better(ggml_type type, int nrc_y) { #ifdef __AVX2__ switch (type) { - case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; - case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; - case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_S : return nrc_y >= 16 ? GGML_TYPE_Q8_K_R8 : type; @@ -267,13 +264,16 @@ struct MulMat { case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #else switch (type) { - case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; + case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; - case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; + case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #endif @@ -815,7 +815,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: - return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false; + return iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16); case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index abd4be61..0384e49a 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7397,7 +7397,7 @@ void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) { } namespace { -template <int block_size, int group_size, int num_bits, bool is_abs = false> +template <int block_size, int group_size, int num_bits, bool is_abs = false, bool is_int = false> class QuantizerIQKT { static_assert(group_size == 8 || group_size == 4); static_assert(block_size >= 8 && block_size%8 == 0); @@ -7408,7 +7408,7 @@ public: constexpr static int kNg = kBlockSize/kGroupSize; constexpr static int kNblock = kSuperBlockSize/kBlockSize; constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8 - constexpr static float kScale = 31.75f; + constexpr static float kScale = is_int ? 1.f : 31.75f; constexpr static bool kVerbose = false; QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096); @@ -7419,17 +7419,32 @@ public: inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const; static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; - constexpr uint32_t kmask = 0x8fff8fff; - constexpr uint32_t km32 = 0x3b603b60; uint32_t x = i + offset; - for (int k = 0; k < kGroupSize; ++k) { - x = ka*x + kb; - uint32_t s = (x & kmask) ^ km32; - float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); - if constexpr (is_abs) result[k] = scale*std::abs(val); - else result[k] = scale*val; + if constexpr (is_int) { + constexpr uint32_t ka = 0xCBAC1FED; + uint32_t s; + auto i8 = (const int8_t *)&s; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x; + s = x & 0x3f3f3f3f; + if constexpr (is_abs) { + result[k] = scale*std::abs(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } else { + result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } + } + } else { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x + kb; + uint32_t s = (x & kmask) ^ km32; + float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); + if constexpr (is_abs) result[k] = scale*std::abs(val); + else result[k] = scale*val; + } } } @@ -7478,14 +7493,15 @@ private: float m_mid[4*kGroupSize]; }; -template <int block_size, int group_size, int num_bits, bool is_abs> -QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { +template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int> +QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { m_values.resize(kNumVal*kGroupSize); float * data = m_values.data(); for (int i = 0; i < kNumVal; ++i) { set_values(i, data, kScale, offset); data += kGroupSize; } + if (num_clusters == 0) return; // Make 128 clusters. // Note: we get a slightly better result by using 64 clusters // at the expense of almost doubling the quantization time. @@ -7494,8 +7510,8 @@ QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_c m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values); } -template <int block_size, int group_size, int num_bits, bool is_abs> -std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_scale( +template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int> +std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::find_best_scale( const float * xb, const float * weight, const int * best_idx) const { float sumqx = 0, sumq2 = 0; #ifdef __AVX2__ @@ -7527,8 +7543,8 @@ std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>: return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f); } -template <int block_size, int group_size, int num_bits, bool is_abs> -float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse_scale( +template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int> +float QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::find_best_inverse_scale( const float * xb, const float * weight, const int * best_idx) const { float sumqx = 0, sumx2 = 0; #ifdef __AVX2__ @@ -7560,8 +7576,8 @@ float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse return sumx2 > 0 ? sumqx/sumx2 : 0.f; } -template <int block_size, int group_size, int num_bits, bool is_abs> -void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { +template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int> +void QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { if (!d) { std::memset(best_idx, 0, kNg*sizeof(int)); return; @@ -7739,8 +7755,8 @@ void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(fl #endif } -template <int block_size, int group_size, int num_bits, bool is_abs> -std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::finalize_clusters(int num_neighbours, +template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int> +std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::finalize_clusters(int num_neighbours, const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) { int ncluster = clusters.size()/kGroupSize; std::vector<std::vector<int>> p_in_cluster(ncluster); @@ -7826,8 +7842,8 @@ std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is return p_in_cluster; } -template <int block_size, int group_size, int num_bits, bool is_abs> -std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) { +template <int block_size, int group_size, int num_bits, bool is_abs, bool is_int> +std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs, is_int>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) { constexpr int ndim = kGroupSize; GGML_ASSERT(points.size() % ndim == 0); int npoint = points.size() / ndim; @@ -7995,7 +8011,7 @@ std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::clus // ========================================== iq2_kt ==================================================== -using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16>; +using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16, false, true>; const QuantizerIQ2KT& iq2kt_quantizer() { static std::mutex mutex; @@ -8006,7 +8022,7 @@ const QuantizerIQ2KT& iq2kt_quantizer() { } void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights, - float * qtmp) { + int * all_idx) { constexpr float kSigmaScale = 2.0f; using Q = QuantizerIQ2KT; @@ -8025,6 +8041,11 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + float amax_row = 0; + for (int j = 0; j < n_per_row; ++j) { + amax_row = std::max(amax_row, std::abs(x[j])); + } + float amax_scale = 0, max_scale = 0; for (int ibl = 0; ibl < nblock; ++ibl) { @@ -8042,9 +8063,10 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f float ax = std::abs(xb[j]); amax = std::max(amax, ax); } - quantizer.find_best_match( amax/96.f, xb, weight, best_idx); + float scale_0 = std::max(90.f, 124.f*amax/amax_row); + quantizer.find_best_match( amax/scale_0, xb, weight, best_idx); auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); - quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg); + quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx + Q::kNg); auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg); auto idx = best_idx; @@ -8052,12 +8074,7 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f else { scales[ib] = dm; idx += Q::kNg; } - auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; - for (int ig = 0; ig < Q::kNg; ++ig) { - auto q = quantizer.values() + idx[ig]*Q::kGroupSize; - for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j]; - qt += Q::kGroupSize; - } + for (int ig = 0; ig < Q::kNg; ++ig) all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig] = idx[ig]; float abs_scale = std::abs(scales[ib]); if (abs_scale > amax_scale) { @@ -8080,20 +8097,22 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f float sumqx = 0, sumq2 = 0; for (int ibl = 0; ibl < nblock; ++ibl) { const float * xb = x + ibl*Q::kSuperBlockSize; - const float * qb = qtmp + ibl*Q::kSuperBlockSize; const float * wb = all_weights + ibl*Q::kSuperBlockSize; auto scales = all_scales + ibl*Q::kNblock; for (int ib = 0; ib < Q::kNblock; ++ib) { int ls = best_index_iq4nl(iq4k_values, id*scales[ib]); float dl = iq4k_values[ls]; - for (int j = 0; j < Q::kBlockSize; ++j) { - float q = dl*qb[j]; - sumqx += wb[j]*xb[j]*q; - sumq2 += wb[j]*q*q; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto qb = quantizer.values() + Q::kGroupSize*all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig]; + for (int j = 0; j < Q::kGroupSize; ++j) { + int jj = ig*Q::kGroupSize + j; + float q = dl*qb[j]; + sumqx += wb[jj]*xb[jj]*q; + sumq2 += wb[jj]*q*q; + } } xb += Q::kBlockSize; wb += Q::kBlockSize; - qb += Q::kBlockSize; } } if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { @@ -8129,6 +8148,26 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f float dl = d*ls; quantizer.find_best_match(dl, xb, weight, best_idx); + auto prev_idx = all_idx + (ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize; + + float mse1 = 0, mse2 = 0; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q1 = quantizer.values() + Q::kGroupSize*prev_idx[ig]; + auto q2 = quantizer.values() + Q::kGroupSize*best_idx[ig]; + for (int j = 0; j < Q::kGroupSize; ++j) { + int jj = ig*Q::kGroupSize + j; + float diff1 = xb[jj] - dl*q1[j]; + float diff2 = xb[jj] - dl*q2[j]; + mse1 += weight[jj]*diff1*diff1; + mse2 += weight[jj]*diff2*diff2; + } + } + if (mse1 < mse2) { + for (int ig = 0; ig < Q::kNg; ++ig) best_idx[ig] = prev_idx[ig]; + } else { + for (int ig = 0; ig < Q::kNg; ++ig) prev_idx[ig] = best_idx[ig]; + } + for (int j = 0; j < Q::kNg; ++j) { qs[j] = best_idx[j]; auto xl = xb + Q::kGroupSize*j; @@ -8196,10 +8235,10 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row); std::vector<float> scales(n_per_row/QuantizerIQ2KT::kBlockSize); std::vector<float> weights(n_per_row); - std::vector<float> xtmp(n_per_row); + std::vector<int> idx(n_per_row/QuantizerIQ2KT::kGroupSize); char * qrow = (char *)dst; for (int64_t row = 0; row < nrows; ++row) { - quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data()); + quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), idx.data()); src += n_per_row; qrow += row_size; } @@ -8209,7 +8248,7 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) { assert(k % QuantizerIQ2KT::kSuperBlockSize == 0); #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return; #endif const int nb = k / QuantizerIQ2KT::kSuperBlockSize; const float * dptr = (const float *)x; @@ -8254,7 +8293,7 @@ void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx namespace { -using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>; +using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true, true>; const QuantizerIQ3KT& iq3kt_quantizer() { static std::mutex mutex; std::lock_guard<std::mutex> lock(mutex); @@ -8465,7 +8504,7 @@ size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) { #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return; #endif using Q = QuantizerIQ3KT; constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; @@ -8521,7 +8560,7 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx namespace{ -using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>; +using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15, false, true>; const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { static std::mutex mutex; @@ -8536,6 +8575,14 @@ const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { return *quantizer1; } +const QuantizerIQ4KT& iq4kt_dequantizer() { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + static std::unique_ptr<QuantizerIQ4KT> dequantizer; + if (!dequantizer) dequantizer = std::make_unique<QuantizerIQ4KT>(0, 0, 4096); + return *dequantizer; +} + void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) { constexpr float kSigmaScale = 2.0f; @@ -8546,7 +8593,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f float * dptr = (float *)vy; - block_iq4_kt * y = (block_iq4_kt *)(dptr + 2); + block_iq4_kt * y = (block_iq4_kt *)(dptr + 1); auto& quantizer1 = iq4kt_quantizer(); auto& quantizer2 = iq4kt_quantizer(true); @@ -8555,13 +8602,10 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); - float amax_row = 0, row_av = 0; + float amax_row = 0; for (int j = 0; j < n_per_row; ++j) { - row_av += x[j]; amax_row = std::max(amax_row, std::abs(x[j])); } - row_av /= n_per_row; - dptr[1] = row_av; if (!amax_row) { dptr[0] = 0.f; std::memset(y, 0, nblock*sizeof(block_iq4_kt)); @@ -8584,7 +8628,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; float amax = 0; for (int j = 0; j < Q::kBlockSize; ++j) { - xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + xaux[j] = xbl[ib*Q::kBlockSize+j]; float ax = std::abs(xaux[j]); amax = std::max(amax, ax); } @@ -8593,7 +8637,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f continue; } float best = 0; - float scale_0 = std::max(92.f, 127.f*amax/amax_row); + float scale_0 = std::max(90.f, 124.f*amax/amax_row); for (int itry = -kNtry; itry <= kNtry; ++itry) { quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx); auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx); @@ -8664,7 +8708,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f for (int ib = 0; ib < Q::kNblock; ++ib) { auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1; const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; - for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j]; int ls = nearest_int(id*scales[ib]); ls = std::min(ls, 63); *(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1); @@ -8724,7 +8768,7 @@ size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return; #endif using Q = QuantizerIQ4KT; assert(k % Q::kSuperBlockSize == 0); @@ -8732,23 +8776,20 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { const int nb = k / Q::kSuperBlockSize; const float * dptr = (const float *)x; const float d = dptr[0] * Q::kScale; - const float row_av = dptr[1]; - x = (const block_iq4_kt *)(dptr + 2); - auto& deq = iq4kt_quantizer(); + x = (const block_iq4_kt *)(dptr + 1); + auto& deq = iq4kt_dequantizer(); for (int ibl = 0; ibl < nb; ++ibl) { auto shb = x[ibl].qs; auto ql = (const uint8_t *)(shb + Q::kNblock); auto qh = ql + kNumGroups; for (int ib = 0; ib < Q::kNblock; ++ib) { int offset = shb[ib] & 1 ? 32768 + 4096 : 4096; - //auto& deq = shb[ib] & 1 ? deq2 : deq1; int ls = int((shb[ib] & 0xff) >> 1) - 64; float sl = d * ls; for (int ig = 0; ig < Q::kNg; ++ig) { int jj = ib*Q::kNg+ig; uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12); deq.set_values(idx, y, sl, offset); - for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av; y += Q::kGroupSize; } } diff --git a/src/llama.cpp b/src/llama.cpp index dfd53337..af8ef9be 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18627,6 +18627,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4|| + new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT || new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4) { if (nx % QK_K != 0) { LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); |