diff options
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 32 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 12 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/vecdotq.cuh | 47 |
4 files changed, 98 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7ea93264..8549c4e5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -670,6 +670,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> { static constexpr int qk = QK_K; static constexpr int qr = QR3_S; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 66e68a52..e7732cf5 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -521,6 +521,28 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst } } +template<typename dst_t> +static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const int64_t i = blockIdx.x; + const block_iq4_k * x = (const block_iq4_k *)vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + const float d = (float)x[i].d; + const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2); + const float d1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32); + const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); + const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1); + const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d1 * values1[q4[j] & 0xf]; + y[j+16] = d2 * values2[q4[j] >> 4]; + } +} + template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); @@ -650,6 +672,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y); } +template<typename dst_t> +static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y); +} + template <typename src_t, typename dst_t> static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -714,6 +742,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: @@ -765,6 +795,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index b44000cd..5da32d99 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -24,6 +24,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 : type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : + type == GGML_TYPE_IQ4_K ? vec_dot_iq4_k_q8_1 : type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : nullptr; } @@ -46,6 +47,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : + type == GGML_TYPE_IQ4_K ? VDR_IQ4_K_Q8_1_MMVQ : 1; } @@ -343,6 +345,13 @@ static void mul_mat_vec_iq4_xs_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +static void mul_mat_vec_iq4_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + static void mul_mat_vec_iq3_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -431,6 +440,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ4_K: + mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ3_S: mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 1248eacd..9f2b2300 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1227,3 +1227,50 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds); return d * sumi; } + +static __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values, + int & val1, int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + const uint8_t * values = all_values + 16*(shift & 1); + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + values = all_values + 8*(shift & 2); + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + +#define VDR_IQ4_K_Q8_1_MMVQ 4 +#define VDR_IQ4_K_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_iq4_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx; + const uint8_t * all_values = (const uint8_t *)iq4k_values; + + // iqs is 0...28 + const int ib32 = iqs/4; + // Why iqs/4 ? + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32; + const uint16_t extra = bq4->extra >> 2*ib32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16); + get_int_from_table_16_shift(aux32, extra, all_values, v1, v2); + sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2); + } + const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds); + const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2); + const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32; + return d * (sumi1 * ls1 + sumi2 * ls2); +} + |