Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--   ggml/src/ggml-cuda/common.cuh                                  7
-rw-r--r--   ggml/src/ggml-cuda/convert.cu                                 31
-rw-r--r--   ggml/src/ggml-cuda/iqk_mmvq.cu                                41
-rw-r--r--   ggml/src/ggml-cuda/iqk_mmvq.cuh                                5
-rw-r--r--   ggml/src/ggml-cuda/mmq.cu                                      4
-rw-r--r--   ggml/src/ggml-cuda/mmq.cuh                                     6
-rw-r--r--   ggml/src/ggml-cuda/mmvq.cu                                     4
-rw-r--r--   ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt.cu  81
8 files changed, 178 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 38b52fd0..15485f60 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -572,6 +572,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> {
 };
 
 template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_KT> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
+template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
     static constexpr int qk = QK_K;
     static constexpr int qr = QR4_XS;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index c8e02a83..8c03ae1b 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -359,6 +359,26 @@ float __device__ __forceinline__ trellis_next(uint32_t& val) {
 }
 
 template<typename dst_t>
+static __global__ void dequantize_block_iq1_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+    int64_t ii = blockIdx.x;
+    int64_t row = (QK_K * ii) / n_per_row;
+    const char * cx = (const char *)vx + row * row_size;
+    float scale = *(const float *)cx;
+    const block_iq1_kt * x = (const block_iq1_kt *)(cx + sizeof(float));
+    const int64_t i = ii - (row*n_per_row)/QK_K;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t ib = tid; // 0...31
+    dst_t * y = yy + ii*QK_K + 8*ib;
+    uint32_t idx = (x[i].ql[ib] | ((x[i].qh[ib%16] << (8 - 4*(ib/16))) & 0xf00) | ((x[i].sh[ib/4] << (8 - (ib%4))) & 0x1000)) + 4096;
+    const float dl = scale * iq4k_values[x[i].sh[ib/4] & 0xf];
+    for (int j = 0; j < 8; ++j) {
+        y[j] = dl * trellis_next_int(idx);
+    }
+}
+
+template<typename dst_t>
 static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
 
     int64_t ii = blockIdx.x;
@@ -1506,6 +1526,13 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
 }
 
 template<typename dst_t>
+static void dequantize_row_iq1_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq1_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ1_KT, n_per_row));
+}
+
+template<typename dst_t>
 static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
     const int nb = k / QK_K;
@@ -1888,6 +1915,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ1_KT:
+            return dequantize_row_iq1_kt_cuda;
         case GGML_TYPE_IQ2_KT:
             return dequantize_row_iq2_kt_cuda;
         case GGML_TYPE_IQ3_KT:
@@ -1987,6 +2016,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ1_KT:
+            return dequantize_row_iq1_kt_cuda;
         case GGML_TYPE_IQ2_KT:
             return dequantize_row_iq2_kt_cuda;
         case GGML_TYPE_IQ3_KT:
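The dequantize kernel above relies on trellis_next_int(), which is defined earlier in convert.cu and therefore sits outside this hunk; the mat-vec and MMQ kernels below inline the same step. A minimal host-side sketch of one trellis step, inferred from that inlined form (the name ref_trellis_next is illustrative, not from the patch):

#include <cstdint>

// One step of the IQ1_KT multiplicative trellis (illustrative reference only).
static int ref_trellis_next(uint32_t & val) {
    constexpr uint32_t ka = 0xCBAC1FED; // trellis multiplier used throughout the patch
    constexpr uint32_t km = 0x3f3f3f3f; // keep the low 6 bits of each byte
    val *= ka;                          // advance the 32-bit state
    const uint32_t m = val & km;
    // Sum of the four masked bytes minus 126, which is exactly what
    // ggml_cuda_dp4a(val & km, 0x01010101, -126) computes on the device.
    // Each masked byte is 0..63, so the result lies in [-126, 126] and fits in an int8.
    return (int)((m & 0xff) + ((m >> 8) & 0xff) + ((m >> 16) & 0xff) + ((m >> 24) & 0xff)) - 126;
}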
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index a669390d..c7f5dfb4 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -443,6 +443,39 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
     *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
 }
 
+__device__ __forceinline__ void vec_dot_iq1_kt_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+    constexpr uint32_t ka = 0xCBAC1FED;
+    constexpr uint32_t km = 0x3f3f3f3f;
+
+    float scale = *(const float *)vbq;
+    const block_iq1_kt * bq1 = (const block_iq1_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+    // iqs is 0...28
+    const int ib32 = iqs/4;
+    const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+    const int ls = iq4k_values[bq1->sh[ib32] & 0xf];
+    const float dl = scale * ls;
+    int sumi = 0;
+    for (int j = 0; j < 4; ++j) {
+        uint32_t val = bq1->ql[4*ib32+j] + 4096 + ((bq1->qh[4*(ib32%4)+j] << (8 - 4*(ib32/4))) & 0xf00) + ((bq1->sh[ib32] << (8 - j)) & 0x1000);
+        int v4 = 0;
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+        }
+        sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
+        v4 = 0;
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+        }
+        sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
+    }
+    *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
 __device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
 
@@ -1350,6 +1383,14 @@ void mul_mat_vec_iq4_kt_q8_1_cuda(
     iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
 }
 
+void mul_mat_vec_iq1_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq1_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
 void mul_mat_vec_iq2_kt_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index d14c3541..5d62d02e 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -111,6 +111,11 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
     const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
 
+void mul_mat_vec_iq1_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
 void mul_mat_vec_iq2_kt_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
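Both the dequantize and vec-dot kernels reassemble a 13-bit trellis seed per group of eight weights: ql supplies bits 0-7, one nibble of qh supplies bits 8-11, and one of the four high bits of an sh byte supplies bit 12, while the low nibble of that same sh byte indexes the per-32-weight scale in iq4k_values. A sketch of the unpacking for group ib (0..31) of a 256-weight block; the helper name is hypothetical, the field layout is as the indexing above suggests:

#include <cstdint>

// Reassemble the trellis seed for group ib (0..31) of one block_iq1_kt.
static uint32_t ref_iq1_kt_seed(const uint8_t ql[32], const uint8_t qh[16],
                                const uint8_t sh[8], int ib) {
    uint32_t idx = ql[ib];                                        // bits 0..7
    idx |= ((uint32_t)qh[ib % 16] << (8 - 4*(ib / 16))) & 0xf00;  // bits 8..11: low nibble for ib < 16, high nibble otherwise
    idx |= ((uint32_t)sh[ib / 4]  << (8 - (ib % 4)))    & 0x1000; // bit 12: bit (4 + ib%4) of sh[ib/4]
    return idx + 4096;  // constant offset applied to the seed in all three kernels
}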
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index cde5d044..d417fdc0 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -109,6 +109,9 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_IQ4_KT:
             mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ1_KT:
+            mul_mat_q_case<GGML_TYPE_IQ1_KT>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ2_KT:
             mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream);
             break;
@@ -211,6 +214,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_IQ5_KS:
         case GGML_TYPE_IQ5_KS_R4:
         case GGML_TYPE_IQ2_KS:
+        case GGML_TYPE_IQ1_KT:
         case GGML_TYPE_IQ2_KT:
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 21b50082..aaf02fab 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -100,6 +100,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
         case GGML_TYPE_IQ5_KS:
         case GGML_TYPE_IQ5_KS_R4:
         case GGML_TYPE_IQ6_K:
+        case GGML_TYPE_IQ1_KT:
         case GGML_TYPE_IQ2_KT:
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:
@@ -218,6 +219,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_IQ5_K   : return MMQ_DP4A_TXS_Q8_0_16;
         case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16;
         case GGML_TYPE_IQ6_K   : return MMQ_DP4A_TXS_Q8_0_16;
+        case GGML_TYPE_IQ1_KT  : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ2_KT  : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ3_KT  : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ4_KT  : return MMQ_DP4A_TXS_Q8_0;
@@ -275,6 +277,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_IQ5_K   : return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_IQ6_K   : return MMQ_MMA_TILE_X_K_Q3_K;
+        case GGML_TYPE_IQ1_KT  : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ2_KT  : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ3_KT  : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ4_KT  : return MMQ_MMA_TILE_X_K_Q8_0;
@@ -4176,9 +4179,10 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
-extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_KT);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
 
 // -------------------------------------------------------------------------------------------------------------------------
 
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index d0746031..012b3e5e 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -533,6 +533,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
         case GGML_TYPE_IQ4_KSS:
             mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
             break;
+        case GGML_TYPE_IQ1_KT:
+            mul_mat_vec_iq1_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+            break;
         case GGML_TYPE_IQ2_KT:
             mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
             break;
@@ -704,6 +707,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
         case GGML_TYPE_IQ5_KS_R4:
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ1_M_R4:
+        case GGML_TYPE_IQ1_KT:
         case GGML_TYPE_IQ2_KT:
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt.cu
new file mode 100644
index 00000000..1a3590e5
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt.cu
@@ -0,0 +1,81 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_kt(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+    constexpr uint32_t ka = 0xCBAC1FED;
+    constexpr uint32_t km = 0x3f3f3f3f;
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq1_kt * bxi = (const block_iq1_kt *)(x + i*stride + sizeof(float)) + kbx0;
+
+        int ib32 = kqsx/4;
+        int j = kqsx%4;
+        uint32_t val = bxi->ql[kqsx] + ((bxi->qh[kqsx%16] << (8 - 4*(kqsx/16))) & 0xf00) + ((bxi->sh[kqsx/4] << (8 - (kqsx%4))) & 0x1000) + 4096;
+        int2 v = {0, 0};
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+        }
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
+        }
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const float * dptr = (const float *)(x + i*stride);
+        const float d = dptr[0];
+        const block_iq1_kt * bxi = (const block_iq1_kt *)(dptr + 1) + kbx0;
+        const int ls = iq4k_values[bxi->sh[threadIdx.x % 8] & 0xf];
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_KT> {
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_kt<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+DECL_MMQ_CASE(GGML_TYPE_IQ1_KT);
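For reference, the two sketches above combine into a host-side dequantizer for one 8-weight group that mirrors dequantize_block_iq1_kt. The scale table is passed in rather than hard-coded, since only entries 0..15 of iq4k_values are reachable through sh & 0xf. Assuming the fields really are ql[32], qh[16] and sh[8], as the indexing suggests, a block spends 56 bytes on 256 weights, i.e. 1.75 bits per weight plus one float scale per row.

#include <cstdint>

// Dequantize one 8-weight group of an iq1_kt block; a cross-check for the
// kernels above, not part of the patch. scale_tbl stands in for iq4k_values,
// and the two helpers are the hypothetical sketches from earlier.
static void ref_iq1_kt_group(float row_scale, const int8_t * scale_tbl,
                             const uint8_t ql[32], const uint8_t qh[16],
                             const uint8_t sh[8], int ib, float out[8]) {
    uint32_t val   = ref_iq1_kt_seed(ql, qh, sh, ib);         // 13-bit seed plus the 4096 offset
    const float dl = row_scale * scale_tbl[sh[ib / 4] & 0xf]; // per-32-weight scale
    for (int j = 0; j < 8; ++j) {
        out[j] = dl * ref_trellis_next(val);                  // eight successive trellis values
    }
}

Note the design choice visible in mmq_type_traits: load_tiles_iq1_kt expands the trellis into plain int8 tiles with per-32-weight float scales, matching the Q8_0 tile layout, so the generic vec_dot_q8_0_q8_1_mma and vec_dot_q8_0_q8_1_dp4a kernels are reused unchanged.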