diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-05 07:24:31 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-05 07:24:31 +0300 |
commit | 7e79665a31129597634bcef403512aaf4fcdeef9 (patch) | |
tree | f4ffecdadfba5cf770fc0f88426d77ff0bb5a471 | |
parent | f6d5fbdc5780b6dca770c896b8463de3239c7f8b (diff) |
CUDA implementation for IQ1_S_R4 (#492)
* iq1_s_r4: CUDA dequantize
* iq1_s_r4: CUDA GEMV
* iq1_s_r4: MMQ on CUDA
Requires Turing or better (will fall back to dequantize+cuBLAS on older cards).
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda.cu | 1 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 49 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 40 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 5 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cu | 7 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmq.cuh | 80 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 4 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_r4.cu | 5 |
9 files changed, 198 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 7f4b01d4..de9816fe 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3476,6 +3476,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ5_KS_R4: + case GGML_TYPE_IQ1_S_R4: return true; default: return false; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 896ba0df..bc7fadb0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -516,6 +516,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S_R4> { + static constexpr int qk = 32; + static constexpr int qr = 2; + static constexpr int qi = 4; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> { static constexpr int qk = QK_K; static constexpr int qr = QR1_M; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 644ee316..b2f77e09 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -527,6 +527,41 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ } template<typename dst_t> +static __global__ void dequantize_block_iq1_s_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + + int64_t nblock = n_per_row/32; + int64_t row = (8*ii)/nblock; + int64_t row4 = row/4; + int64_t ir = row%4; + int64_t ibl = (8*ii)%nblock; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + + const half * dptr = (const half *)((const char *)vx + 4*row4*row_size); + const float d = (float)dptr[ir]; + const block_iq1_s_r4 * x = (const block_iq1_s_r4 *)(dptr + 4) + ibl; + dst_t * y = yy + 256*ii + 32*ib + 8*il; + + float dl = d*(2*((x[ib].qh[ir] >> 12) & 7) + 1); + float delta = dl * (x[ib].qh[ir] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA); + + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[ib].qs[4*il+ir] | (((x[ib].qh[ir] >> 3*il) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 8; ++j) y[j] = __float2bfloat16(dl*q[j] + delta); + } else { + for (int j = 0; j < 8; ++j) y[j] = dl*q[j] + delta; + } +} + +template<typename dst_t> static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int64_t i = blockIdx.x; @@ -1399,6 +1434,14 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t } template<typename dst_t> +static void dequantize_row_iq1_s_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_S_R4, n_per_row); + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq1_s_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size); +} + +template<typename dst_t> static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; const int nb = (k + QK_K - 1) / QK_K; @@ -1651,6 +1694,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { return dequantize_row_iq5_k_r4_cuda<nv_bfloat16>; case GGML_TYPE_IQ5_KS_R4: return dequantize_row_iq5_ks_r4_cuda<nv_bfloat16>; + case GGML_TYPE_IQ1_S_R4: + return dequantize_row_iq1_s_r4_cuda<nv_bfloat16>; default: return nullptr; } @@ -1699,6 +1744,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_xxs_cuda; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_cuda; + case GGML_TYPE_IQ1_S_R4: + return dequantize_row_iq1_s_r4_cuda; case GGML_TYPE_IQ1_M: return dequantize_row_iq1_m_cuda; case GGML_TYPE_IQ1_BN: @@ -1790,6 +1837,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_xxs_cuda; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_cuda; + case GGML_TYPE_IQ1_S_R4: + return dequantize_row_iq1_s_r4_cuda; case GGML_TYPE_IQ1_M: return dequantize_row_iq1_m_cuda; case GGML_TYPE_IQ1_BN: diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index ae11ae14..2340b54a 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -353,6 +353,38 @@ __device__ __forceinline__ void vec_dot_iq4_ks_r4_q8_1( } } +// TODO +__device__ __forceinline__ void vec_dot_iq1_s_r4_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + const half * dptr = (const half *)vbq; + const block_iq1_s_r4 * bq1 = (const block_iq1_s_r4 *)(dptr + 4) + kbx; + + // iqs is 0 or 2 + const float d8 = __low2float(bq8_1->ds); + const int32_t * q8 = (const int *)bq8_1->qs; + + int32_t grid32[2]; + const int * igrid = (const int *)grid32; + + int minus = 0; + for (int k = 0; k < 4; ++k) minus = ggml_cuda_dp4a(0x01010101, q8[4*(iqs/2)+k], minus); + + for (int i = 0; i < 4; ++i) { + float dl = (float)dptr[i]*(2*((bq1->qh[i] >> 12) & 7) + 1) * d8; + float ml = dl * (bq1->qh[i] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA); + grid32[0] = iq1s_grid_gpu[bq1->qs[4*iqs+i] | (((bq1->qh[i] >> 3*iqs) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + int sumi = ggml_cuda_dp4a(igrid[0], q8[4*(iqs/2)+0], ggml_cuda_dp4a(igrid[1], q8[4*(iqs/2)+1], 0)); + grid32[0] = iq1s_grid_gpu[bq1->qs[4*iqs+i+4] | (((bq1->qh[i] >> (3*iqs+3)) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + sumi = ggml_cuda_dp4a(igrid[0], q8[4*(iqs/2)+2], ggml_cuda_dp4a(igrid[1], q8[4*(iqs/2)+3], sumi)); + result[i] += dl * sumi + ml * minus; + } +} + #define VDR_IQ4_KS_Q8_1_MMVQ 4 #define VDR_IQ4_KS_Q8_1_MMQ 4 @@ -1106,6 +1138,14 @@ void mul_mat_vec_iq4_ks_r4_q8_1_cuda( iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS_R4, 2, vec_dot_iq4_ks_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq1_s_r4_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S_R4, 2, vec_dot_iq1_s_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq5_k_r4_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 5278ef1e..1e4257e8 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -90,3 +90,8 @@ void mul_mat_vec_iq5_ks_r4_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq1_s_r4_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 2f7a9bfd..60c2037f 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -85,6 +85,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ1_S: mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream); break; + case GGML_TYPE_IQ1_S_R4: + mul_mat_q_case<GGML_TYPE_IQ1_S_R4>(ctx, args, stream); + break; case GGML_TYPE_IQ4_XS: mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream); break; @@ -150,6 +153,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_KS: @@ -174,6 +178,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { if (int8_mma_available(cc)) { return true; } + if (type == GGML_TYPE_IQ1_S_R4) { + return false; + } if (cc < MIN_CC_DP4A) { return false; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 7a51c514..c98fa561 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -79,6 +79,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ3_S: return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_S_R4: return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: @@ -186,6 +187,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ3_XXS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ3_S : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S_R4: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; @@ -231,6 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ3_XXS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ3_S : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S_R4: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; @@ -318,6 +321,74 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin } } +template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = threadIdx.x / 4; + const int kqsx = threadIdx.x % 4; + + int32_t grid32[2]; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + const int i4 = i/4; + const int ir = i%4; + + const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(x + 4*i4*stride + 4*sizeof(half)) + kbx0 + kbx; + + grid32[0] = iq1s_grid_gpu[bxi->qs[4*kqsx+ir] | (((bxi->qh[ir] >> 3*kqsx) & 7) << 8)]; + grid32[1] = ((grid32[0] >> 4) & 0x0f0f0f0f) << 3; + grid32[0] = (grid32[0] & 0x0f0f0f0f) << 3; + const int shift = bxi->qh[ir] & 0x8000 ? 0x09090909 : 0x07070707; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 0] = __vsubss4(grid32[0], shift); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 1] = __vsubss4(grid32[1], shift); +#else + // TODO + //x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; +#endif // INT8_MMA_AVAILABLE + } + + const int blocks_per_tile_x_row = WARP_SIZE / 4; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + const int i4 = i/4; + const int ir = i%4; + + const half * dptr = (const half *)(x + 4*i4*stride); + const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(dptr + 4) + kbx0 + kbxd; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.125f * __half2float(dptr[ir]) * (((bxi->qh[ir] >> 11) & 14) + 1); +#else + // TODO + //x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE + } +} + template <int mmq_x, int mmq_y, int nwarps> static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { @@ -3133,6 +3204,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> { }; template <int mmq_x, int mmq_y, int nwarps, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S_R4> { + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4<mmq_y, nwarps, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>; +}; + +template <int mmq_x, int mmq_y, int nwarps, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>; @@ -3656,6 +3735,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 6b5d37fa..cc00d278 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -560,6 +560,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ5_KS_R4: mul_mat_vec_iq5_ks_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ1_S_R4: + mul_mat_vec_iq1_s_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -679,6 +682,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ5_KS_R4: + case GGML_TYPE_IQ1_S_R4: return true; default: return false; diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_r4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_r4.cu new file mode 100644 index 00000000..7bc5cf1a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_r4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); |