diff options
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 54 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 41 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 5 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cu | 4 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cuh | 73 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 3 |
7 files changed, 187 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 0a7f7f83..a04a1929 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -600,6 +600,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ5_KS> { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_XS; + static constexpr int qi = QI5_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ6_K> { static constexpr int qk = QK_K; static constexpr int qr = QR6_XS; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 8383f2d3..5afe8c74 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -696,6 +696,46 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_ } } + +template<typename dst_t> +static __global__ void dequantize_block_iq5_ks(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float d = *(const float *)cx; + const block_iq5_ks * x = (const block_iq5_ks *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int tid = threadIdx.x; + int ib64 = tid/8; // 0...3 + int il = tid%8; // 0...7 + dst_t * y = yy + ii*QK_K + 64*ib64 + 2*il; + const float dl1 = d * ((int)(x[i].scales[2*ib64+0] & 254) - 127); + const float dl2 = d * ((int)(x[i].scales[2*ib64+1] & 254) - 127); + const uint8_t * qs = x[i].qs + 32*ib64 + 2*il; + const uint8_t * qh = x[i].qh + 2*il; + auto values1 = iq5nl_values + ((x[i].scales[2*ib64+0] & 1) << 5); + auto values2 = iq5nl_values + ((x[i].scales[2*ib64+1] & 1) << 5); + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = __float2bfloat16(dl1 * values1[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4)]); + y[j+16] = __float2bfloat16(dl1 * values1[(qs[j+16] & 0xf) | ((h2 & 1) << 4)]); + y[j+32] = __float2bfloat16(dl2 * values2[(qs[j+ 0] >> 4) | ((h1 & 2) << 3)]); + y[j+48] = __float2bfloat16(dl2 * values2[(qs[j+16] >> 4) | ((h2 & 2) << 3)]); + } + } else { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = dl1 * values1[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4)]; + y[j+16] = dl1 * values1[(qs[j+16] & 0xf) | ((h2 & 1) << 4)]; + y[j+32] = dl2 * values2[(qs[j+ 0] >> 4) | ((h1 & 2) << 3)]; + y[j+48] = dl2 * values2[(qs[j+16] >> 4) | ((h2 & 2) << 3)]; + } + } +} + template<typename dst_t> static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -1009,6 +1049,14 @@ static void dequantize_row_iq4_ks_cuda(const void * vx, dst_t * y, const int64_t } template<typename dst_t> +static void dequantize_row_iq5_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ5_KS, n_per_row); + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq5_ks<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size); +} + +template<typename dst_t> static void dequantize_row_iq4_kss_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row); @@ -1140,6 +1188,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { return dequantize_row_iq4_kss_cuda<nv_bfloat16>; case GGML_TYPE_IQ4_KS: return dequantize_row_iq4_ks_cuda<nv_bfloat16>; + case GGML_TYPE_IQ5_KS: + return dequantize_row_iq5_ks_cuda<nv_bfloat16>; case GGML_TYPE_IQ4_K: return dequantize_row_iq4_k_cuda<nv_bfloat16>; case GGML_TYPE_IQ5_K: @@ -1202,6 +1252,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_ks_cuda; case GGML_TYPE_IQ4_KSS: return dequantize_row_iq4_kss_cuda; + case GGML_TYPE_IQ5_KS: + return dequantize_row_iq5_ks_cuda; case GGML_TYPE_IQ2_KS: return dequantize_row_iq2_ks_cuda; case GGML_TYPE_IQ2_K: @@ -1273,6 +1325,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_ks_cuda; case GGML_TYPE_IQ4_KSS: return dequantize_row_iq4_kss_cuda; + case GGML_TYPE_IQ5_KS: + return dequantize_row_iq5_ks_cuda; case GGML_TYPE_IQ2_KS: return dequantize_row_iq2_ks_cuda; case GGML_TYPE_IQ2_K: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 576c387d..6a2db725 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -328,6 +328,39 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1( return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2); } +__device__ __forceinline__ float vec_dot_iq5_ks_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + float scale = *(const float *)vbq; + const block_iq5_ks * bq5 = (const block_iq5_ks *)((const char *)vbq + sizeof(float)) + kbx; + const uint8_t * all_values = (const uint8_t *)iq5nl_values; + + int i4 = iqs/4; // 0...7. Blocks of 16 index is 4*(i4/2) + (i4%2) + (0 and 2) + + const int32_t * q8_1 = (const int *)bq8_1[2*(i4/2)+0].qs + 4*(i4%2); + const int32_t * q8_2 = (const int *)bq8_1[2*(i4/2)+1].qs + 4*(i4%2); + const uint32_t * q4 = (const uint32_t *)bq5->qs + 8*(i4/2) + 4*(i4%2); + const uint32_t * qh = (const uint32_t *)bq5->qh + 4*(i4%2); + const uint8_t * values1 = all_values + ((bq5->scales[2*(i4/2)+0] & 1) << 5); + const uint8_t * values2 = all_values + ((bq5->scales[2*(i4/2)+1] & 1) << 5); + uint32_t aux32[2]; + const uint8_t * a8 = (const uint8_t *)aux32; + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + uint32_t h = qh[j] >> 2*(i4/2); + aux32[0] = ((q4[j] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); + aux32[1] = ((q4[j] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); + v1 = int_from_table(a8+0, values1); + v2 = int_from_table(a8+4, values2); + sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1); + sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2); + } + const int ls1 = (bq5->scales[2*(i4/2)+0] & 254) - 127; + const int ls2 = (bq5->scales[2*(i4/2)+1] & 254) - 127; + return scale * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2); +} + #define VDR_IQ6_K_Q8_1_MMVQ 4 #define VDR_IQ6_K_Q8_1_MMQ 4 @@ -799,6 +832,14 @@ void mul_mat_vec_iq5_k_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq5_ks_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_KS, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq6_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 1f55ddb9..b81d2114 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -26,6 +26,11 @@ void mul_mat_vec_iq5_k_q8_1_cuda( const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); +void mul_mat_vec_iq5_ks_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + void mul_mat_vec_iq6_k_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 7bee10cb..2f7a9bfd 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -94,6 +94,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KS: mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream); break; + case GGML_TYPE_IQ5_KS: + mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_q_case<GGML_TYPE_IQ2_KS>(ctx, args, stream); break; @@ -150,6 +153,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ3_K: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 1da9a67a..72fa9f13 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -88,6 +88,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ6_K: return MMQ_Q8_1_DS_LAYOUT_D4; default: @@ -187,6 +188,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16; @@ -231,6 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_KS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K; @@ -2794,6 +2797,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin } } +template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq5nl_values; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0]; + const block_iq5_ks * bxi = (const block_iq5_ks *)(dptr + 1) + kbx0; + + int qh = get_int_b4(bxi->qh, kqsx); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((bxi->scales[2*l+0] & 1) * 0x20202020); + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((bxi->scales[2*l+1] & 1) * 0x20202020); + qh >>= 2; + + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 8] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127); +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127); +#endif // INT8_MMA_AVAILABLE + } +} + template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq6_k( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3139,6 +3203,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS> { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; }; +template <int mmq_x, int mmq_y, int nwarps, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> { + static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; +}; + template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup> static __device__ void mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, @@ -3581,6 +3653,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index c6b6ef72..14fe2547 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -530,6 +530,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ5_K: mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ5_KS: + mul_mat_vec_iq5_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; case GGML_TYPE_IQ6_K: mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; |