From 76b97c80645362ac65a2e33043fd8d46bdaf8c56 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 16 Oct 2024 15:18:26 +0300 Subject: Adding IQ4_KSS: 4.0 bpw quants (#89) * iq4_kss: WIP * iq4_kss: CUDA dequantize works So we can run perplexity. Sadly, the result does not look good on the bpw vs quantization error plot. * iq4_kss: slightly better quantization * iq4_kss: another small quantization improvement * iq4_kss: CUDA works TG-128 performance is very decent with 131 t/s for LLaMA-3.1-8B. In comparison, we have 123 t/s for q4_0 and 128 t/s for iq4_ks. I.e., the reduced model size more than offsets the additional bit fiddling required for iq4_kss. * iq4_kss: new bit arrangement - CUDA and Zen4 work Did not lose performance on CUDA. Zen4 is decent, but not great: PP-512(LLaMA-3.1-8B) = 163 t/s. TG-128 is of course better than other 4-bit quants due to smaller model size. We get 14.5 t/s @ 8 threads. * iq4_kss: ARM_NEON. Predictably very slow * iq4_kss: Metal PP is not too bad - just 10% slower than q4_0. But TG is 30% slower, i.e., predictably bad. * iq4_kss: somewhat faster Metal dot product 45.75 t/s -> 48.75 t/s. Still 22% slower than q4_0 * iq4_kss: AVX2 Bad, but better than I expected. PP-512(LLaMA-3.1-8B) = 167 t/s on the Ryzen-5950X. I.e., with 32 AVX2 threads we get the performance of 16 Zen4 threads. * iq4_kss: very slightly faster Metal dot product 48.7 t/s -> 49.3 t/s --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/common.cuh | 7 +++++++ ggml/src/ggml-cuda/convert.cu | 43 +++++++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cu | 36 ++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 4 ++++ ggml/src/ggml-cuda/mmvq.cu | 3 +++ 5 files changed, 93 insertions(+) (limited to 'ggml/src/ggml-cuda') diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a6a9c3d3..a5658a24 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -543,6 +543,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 1e4421b1..e9d15b5d 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -638,6 +638,37 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst } } +template +static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const float *)cx; + const block_iq4_kss * x = (const block_iq4_kss *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + ii*QK_K + 32*ib + 4*il; + const uint32_t * q4 = x[i].qs + 4*ib; + uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6); + uint8_t ls = (s32 | (s32 >> 15)) & 0xff; + const float d = scale * ((ls & 254) - 127); + const int8_t * values = iq4k_values + ((ls & 1) << 4); + uint32_t aux32[2]; + aux32[0] = q4[il] & 0xfffefffe; + aux32[0] ^= (aux32[0] >> 1); + aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f); + aux32[0] &= 0x0f0f0f0f; + const uint8_t * aux8 = (const uint8_t *)aux32; + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[aux8[j+0]]; + y[j+16] = d * values[aux8[j+4]]; + } +} + template static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int64_t i = blockIdx.x; @@ -980,6 +1011,14 @@ static void dequantize_row_iq4_ks_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_ks<<>>(vx, y, n_per_row, row_size); } +template +static void dequantize_row_iq4_kss_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row); + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_kss<<>>(vx, y, n_per_row, row_size); +} + template static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1152,6 +1191,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_KS: return dequantize_row_iq4_ks_cuda; + case GGML_TYPE_IQ4_KSS: + return dequantize_row_iq4_kss_cuda; case GGML_TYPE_IQ2_KS: return dequantize_row_iq2_ks_cuda; case GGML_TYPE_IQ2_K: @@ -1225,6 +1266,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_KS: return dequantize_row_iq4_ks_cuda; + case GGML_TYPE_IQ4_KSS: + return dequantize_row_iq4_kss_cuda; case GGML_TYPE_IQ2_KS: return dequantize_row_iq2_ks_cuda; case GGML_TYPE_IQ2_K: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 9ca219e4..dec54b5e 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -239,6 +239,35 @@ __device__ __forceinline__ float vec_dot_iq4_ks_q8_1( return dl * __low2float(bq8_1[ib32].ds) * sumi; } +#define VDR_IQ4_KSS_Q8_1_MMVQ 4 +#define VDR_IQ4_KSS_Q8_1_MMQ 4 + +__device__ __forceinline__ float vec_dot_iq4_kss_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + float scale = *(const float *)vbq; + const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx; + const uint8_t * all_values = (const uint8_t *)iq4k_values; + + // iqs is 0...28 + const int ib32 = iqs/4; // Why iqs/4 ? + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; + uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6); + uint8_t ls = (s32 | (s32 >> 15)) & 0xff; + const float dl = scale * ((ls & 254) - 127); + int v1, v2; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + uint32_t aux32 = q4[j] & 0xfffefffe; + aux32 ^= (aux32 >> 1); + get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2); + sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi); + sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi); + } + return dl * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ5_K_Q8_1_MMVQ 4 #define VDR_IQ5_K_Q8_1_MMQ 4 @@ -703,6 +732,13 @@ void mul_mat_vec_iq4_ks_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +void mul_mat_vec_iq4_kss_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + void mul_mat_vec_iq2_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 3a93a1b6..0678c026 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -32,6 +32,10 @@ void mul_mat_vec_iq4_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); +void mul_mat_vec_iq4_kss_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + void mul_mat_vec_iq2_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index e312b266..107caf45 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -462,6 +462,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ4_KS: mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ4_KSS: + mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; -- cgit v1.2.3