diff options
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 57 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 23 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 4 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 9 |
5 files changed, 84 insertions, 16 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 516e74d8..fbc52aa9 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -677,6 +677,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_K> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> { static constexpr int qk = QK_K; static constexpr int qr = QR4_XS; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index f388e9f3..ed7e4bd0 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -595,6 +595,33 @@ static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_ } } +template<typename dst_t> +static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq3_k * x = (const block_iq3_k *) vx; + + const int tid = threadIdx.x; + int ib128 = tid/16; // 0 or 1 + int il = tid%16; // 0...15 + dst_t * y = yy + i*QK_K + 128*ib128 + 2*il; + const float d = (float)x[i].d * 1.01f; //1.0125f; + const uint16_t sh = x[i].scales_h >> (8*ib128 + (il/8)); + const float dl1 = d * ((2*((x[i].scales_l[4*ib128+0] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x01) ? -1 : 1)); + const float dl2 = d * ((2*((x[i].scales_l[4*ib128+1] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x04) ? -1 : 1)); + const float dl3 = d * ((2*((x[i].scales_l[4*ib128+2] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x10) ? -1 : 1)); + const float dl4 = d * ((2*((x[i].scales_l[4*ib128+3] >> 4*(il/8)) & 0xf) + 1) * ((sh & 0x40) ? -1 : 1)); + const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; + const uint8_t * qh = x[i].qh + 2*il; + const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; + y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; + y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; + y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + } +} template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { @@ -726,21 +753,27 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t } template<typename dst_t> -static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { +static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; - dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y); + dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y); } template<typename dst_t> -static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { +static void dequantize_row_iq3_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; - dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y); + dequantize_block_iq3_k<<<nb, 32, 0, stream>>>(vx, y); } template<typename dst_t> -static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { +static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; - dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y); + dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y); } template <typename src_t, typename dst_t> @@ -807,12 +840,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ5_K: return dequantize_row_iq5_k_cuda; - case GGML_TYPE_IQ2_K: - return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: @@ -864,12 +899,14 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda; case GGML_TYPE_IQ5_K: return dequantize_row_iq5_k_cuda; - case GGML_TYPE_IQ2_K: - return dequantize_row_iq2_k_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 3c2277ed..bf7b2aa7 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -245,9 +245,6 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1( return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2); } -#define VDR_IQ2_K_Q8_1_MMVQ 4 -#define VDR_IQ2_K_Q8_1_MMQ 4 - static const __device__ uint32_t iq2k_table[512] = { 0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311, 0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111, @@ -319,6 +316,9 @@ __device__ __forceinline__ int int_from_table_4(const uint8_t * a8, const int * return values[a8[0] | (a8[1] << 2) | (a8[2] << 4) | (a8[3] << 6)]; } +#define VDR_IQ2_K_Q8_1_MMVQ 4 +#define VDR_IQ2_K_Q8_1_MMQ 4 + __device__ __forceinline__ float vec_dot_iq2_k_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { @@ -378,8 +378,18 @@ __device__ __forceinline__ float vec_dot_iq2_k_q8_1( } +#define VDR_IQ3_K_Q8_1_MMVQ 4 +#define VDR_IQ3_K_Q8_1_MMQ 4 + +// TODO +__device__ __forceinline__ float vec_dot_iq3_k_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + return 0; + } +} // namespace + void mul_mat_vec_iq2_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -387,6 +397,13 @@ void mul_mat_vec_iq2_k_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +void mul_mat_vec_iq3_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + void mul_mat_vec_iq4_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 14e5c1c7..9a33af0d 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -4,6 +4,10 @@ void mul_mat_vec_iq2_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); +void mul_mat_vec_iq3_k_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + void mul_mat_vec_iq4_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 93c8ac29..56bf3ebe 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -432,15 +432,18 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ4_XS: mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ2_K: + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; + case GGML_TYPE_IQ3_K: + mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ4_K: mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; case GGML_TYPE_IQ5_K: mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; - case GGML_TYPE_IQ2_K: - mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; case GGML_TYPE_IQ3_S: mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; |