author:    Kawrakow <iwankawrakow@gmail.com>    2025-05-04 12:45:00 +0300
committer: GitHub <noreply@github.com>          2025-05-04 12:45:00 +0300
commit:    f7c9a0f036951fecab32e056df954ebc54f8688f (patch)
tree:      277a7c5ee63fda3841488e38a1dda9d2a43e0094
parent:    13281282986fb6783d0d7d64b3610bfb7085e749 (diff)
CUDA: MMQ for IQ4_KS (#374)
* WIP
* WIP: still getting illegal memory access
* CUDA: MMQ for iq4_ks now works
~25% faster than dequantize+cuBLAS, ~10% slower than Q4_0 MMQ.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/ggml-cuda/mmq.cu                                      |   8
-rw-r--r--  ggml/src/ggml-cuda/mmq.cuh                                     | 148
-rw-r--r--  ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu   |   5
-rw-r--r--  ggml/src/ggml-cuda/vecdotq.cuh                                 |  12

4 files changed, 133 insertions, 40 deletions
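Two details of this change are easy to miss in the diff below. First, the host code now passes src0->nb[1] (the row stride in bytes, nb01) instead of stride00 (the row stride in blocks), and every load_tiles kernel computes its block pointer as (const block_* *)(x + i*stride) rather than (const block_* *) x + i*stride. The reason is that an IQ4_KS row is not a whole number of blocks: it begins with one float row scale followed by the quantized blocks, so only a byte stride works uniformly across types. Second, each IQ4_KS scale byte packs two things: bit 0 selects one of the two 16-entry halves of the non-linear lookup table iq4k_values, and the remaining seven bits encode the sub-block scale as (scales[j] & 254) - 127. The plain-C sketch below illustrates that layout as inferred from load_tiles_iq4_ks in this diff; the exact block_iq4_ks struct, the table contents, and the nibble order follow the usual ggml iq4 conventions and are assumptions made for illustration, not code from this commit.

    #include <stdint.h>

    #define QK_K 256

    /* Two 16-entry halves of a non-linear 4-bit table; bit 0 of each scale
     * byte selects the half. The real table is iq4k_values in the repo;
     * these numbers are placeholders. */
    static const int8_t iq4k_values_sketch[32] = {
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
        -123, -100, -79, -61, -45, -31, -18,  -6, 5, 17, 29, 42, 57, 73, 93, 117,
    };

    /* Assumed block layout: 8 scale bytes, then 128 nibble-packed quants
     * (256 weights per block). */
    typedef struct {
        uint8_t scales[QK_K/32];
        uint8_t qs[QK_K/2];
    } block_iq4_ks_sketch;

    /* A row starts with one float scale d, then n/QK_K blocks. This is why the
     * kernel below computes dptr = (const float *)(x + i*stride) and then
     * bxi = (const block_iq4_ks *)(dptr + 1). */
    static void dequantize_row_iq4_ks_sketch(const void * row, float * y, int64_t n) {
        const float d = *(const float *) row;
        const block_iq4_ks_sketch * x = (const block_iq4_ks_sketch *)((const float *) row + 1);

        for (int64_t ib = 0; ib < n/QK_K; ++ib) {
            const uint8_t * qs = x[ib].qs;
            for (int is = 0; is < QK_K/32; ++is) {                 /* 8 sub-blocks of 32 */
                const uint8_t sc = x[ib].scales[is];
                const float dl = d * (float)((sc & 254) - 127);    /* 7-bit scale        */
                const int8_t * values = iq4k_values_sketch + ((sc & 1) << 4);
                for (int j = 0; j < 16; ++j) {
                    y[j +  0] = dl * values[qs[j] & 0xf];          /* low nibbles first  */
                    y[j + 16] = dl * values[qs[j] >>  4];          /* then high nibbles  */
                }
                qs += 16;
                y  += 32;
            }
        }
    }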
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 3b959182..67897a83 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t nb01 = src0->nb[1];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
@@ -22,7 +23,6 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t ne0 = dst->ne[0];
 
     const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -31,7 +31,7 @@ void ggml_cuda_op_mul_mat_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
@@ -91,6 +91,9 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_IQ4_NL:
            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
            break;
+        case GGML_TYPE_IQ4_KS:
+            mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
+            break;
        default:
            GGML_ABORT("fatal error");
            break;
@@ -128,6 +131,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_KS:
             mmq_supported = true;
             break;
         default:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 753848d7..148697e2 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -82,6 +82,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_KS:
             return MMQ_Q8_1_DS_LAYOUT_D4;
         default:
             GGML_ABORT("fatal error");
@@ -179,6 +180,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_IQ1_S   : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ4_XS  : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ4_NL  : return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_IQ4_KS  : return MMQ_DP4A_TXS_Q8_0;
         default                : return tile_x_sizes{0, 0, 0};
     }
 }
@@ -216,6 +218,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_IQ1_S   : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ4_XS  : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ4_NL  : return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_IQ4_KS  : return MMQ_MMA_TILE_X_K_Q8_0;
         default                : return 0;
     }
 }
@@ -261,7 +264,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+        const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbx;
 
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 #ifdef INT8_MMA_AVAILABLE
@@ -283,7 +286,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+        const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -356,7 +359,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+        const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbx;
 
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 #ifdef INT8_MMA_AVAILABLE
@@ -378,7 +381,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+        const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
@@ -451,7 +454,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+        const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbx;
 
         const int ql = get_int_b2(bxi->qs, kqsx);
         const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
@@ -490,7 +493,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+        const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -523,7 +526,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+        const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbx;
 
         const int ql = get_int_b4(bxi->qs, kqsx);
         const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
@@ -560,7 +563,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+        const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
@@ -593,7 +596,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx;
+        const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbx;
 
         const int ql = get_int_b2(bxi->qs, kqsx);
         const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2);
@@ -623,7 +626,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd;
+        const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -656,7 +659,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+        const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbx;
 
 #ifdef INT8_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
@@ -678,7 +681,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+        const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
@@ -1044,7 +1047,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
+        const block_q2_K * bxi = (const block_q2_K *)(x + i*stride) + kbx0;
 
         const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
 
@@ -1275,7 +1278,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+        const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0;
 
         const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
         const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
@@ -1305,7 +1308,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+        const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0;
 
         const int ksc = threadIdx.x % (WARP_SIZE/8);
 
@@ -1341,7 +1344,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+        const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0;
 
         x_df[i] = bxi->d;
     }
@@ -1412,7 +1415,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+        const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0;
 
         const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
 #ifdef INT8_MMA_AVAILABLE
@@ -1433,7 +1436,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+        const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0;
 
         const int * scales = (const int *) bxi->scales;
         const int ksc = threadIdx.x % (WARP_SIZE/16);
@@ -1462,7 +1465,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+        const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0;
 
         x_dm[i] = bxi->dm;
     }
@@ -1475,7 +1478,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+        const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
 
         const int * scales = (const int *) bxi->scales;
 
@@ -1541,7 +1544,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+        const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
 
         const int ky = QR5_K*threadIdx.x;
         const int ql = get_int_b4(bxi->qs, threadIdx.x);
@@ -1574,7 +1577,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+        const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
 
         const int * scales = (const int *) bxi->scales;
         const int ksc = threadIdx.x % (WARP_SIZE/16);
@@ -1603,7 +1606,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+        const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
 
         x_dm[i] = bxi->dm;
     }
@@ -1616,7 +1619,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+        const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0;
 
         const int * scales = (const int *) bxi->scales;
 
@@ -1683,7 +1686,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
+        const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0;
 
         const int ql = get_int_b2(bxi->ql, threadIdx.x);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
@@ -1716,7 +1719,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+        const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
@@ -1733,7 +1736,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+        const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / 4;
 
 #ifdef INT8_MMA_AVAILABLE
         x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
@@ -1908,7 +1911,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+        const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbx;
 
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
@@ -1933,7 +1936,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+        const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
@@ -1965,7 +1968,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+        const block_iq2_xxs * bxi = (const block_iq2_xxs *)(x + i*stride) + kbx0;
 
         const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
         const uint8_t * aux8 = (const uint8_t *) &q2;
@@ -2023,7 +2026,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+        const block_iq2_xs * bxi = (const block_iq2_xs *)(x + i*stride) + kbx0;
 
         const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
         const uint16_t * q2 = (const uint16_t *) &q2_packed;
@@ -2079,7 +2082,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+        const block_iq2_s * bxi = (const block_iq2_s *)(x + i*stride) + kbx0;
 
         const int qs_packed = get_int_b2(bxi->qs, kqsx);
         const uint8_t * qs = (const uint8_t *) &qs_packed;
@@ -2142,7 +2145,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+        const block_iq3_xxs * bxi = (const block_iq3_xxs *)(x + i*stride) + kbx0;
 
         const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
         const uint8_t * q3 = (const uint8_t *) &q3_packed;
@@ -2198,7 +2201,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+        const block_iq3_s * bxi = (const block_iq3_s *)(x + i*stride) + kbx0;
 
         const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
         const uint8_t * qs = (const uint8_t *) &qs_packed;
@@ -2261,7 +2264,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+        const block_iq1_s * bxi = (const block_iq1_s *)(x + i*stride) + kbx0;
 
         const int qs_packed = get_int_b2(bxi->qs, kqsx);
         const uint8_t * qs = (const uint8_t *) &qs_packed;
@@ -2318,7 +2321,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+        const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;
 
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
@@ -2340,7 +2343,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+        const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0;
 
         const float d = __half2float(bxi->d);
 
@@ -2355,6 +2358,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = 0;           // threadIdx.x / QI4_XS
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4, values);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const float * dptr = (const float *)(x + i*stride);
+        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = dptr[0] * ls;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
@@ -2576,6 +2637,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS> {
+    static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_ks<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
 static __device__ void mul_mat_q_process_tile(
     const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -2608,7 +2677,7 @@ static __device__ void mul_mat_q_process_tile(
     const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
 
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
-        load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+        load_tiles(x + stride01*it*mmq_y, tile_x, kb0, tile_x_max_i, stride01);
 
         {
             const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2889,6 +2958,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
             (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+
     } else {
         constexpr bool need_check = true;
 
@@ -2897,6 +2967,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
         mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
             (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+
     }
 }
@@ -3010,6 +3081,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
 
 // -------------------------------------------------------------------------------------------------------------------------
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu
new file mode 100644
index 00000000..940c2da8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index e9af29b9..cae5e04f 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1131,6 +1131,18 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
     return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
 }
 
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
+    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
+    const int8_t * q0_8   = (const int8_t *) &q0_32;
+    const char4    val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
+
+    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;
+    const int8_t * q1_8   = (const int8_t *) &q1_32;
+    const char4    val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);
+
+    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
 #define VDR_IQ4_NL_Q8_1_MMVQ 2
 #define VDR_IQ4_NL_Q8_1_MMQ  4
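The only change outside the MMQ sources is the new get_int_from_table_16 overload at the end of the diff. Unlike the existing version, which always reads the fixed IQ4_NL table, it takes the lookup table as a parameter, so load_tiles_iq4_ks can pass whichever half of iq4k_values the scale bit selected. Below is a host-side analogue for sanity-checking the nibble expansion on the CPU; the function name table_16_pair and the toy table are hypothetical, introduced here only for illustration, with CUDA's char4 replaced by plain byte arrays.

    #include <cstdint>
    #include <cassert>

    // Split the 8 nibbles of q4 into low-nibble and high-nibble groups and map
    // each nibble through a 16-entry table, mirroring the device code above.
    static void table_16_pair(uint32_t q4, const int8_t * values, int8_t lo[4], int8_t hi[4]) {
        const uint32_t q0 = (q4 >> 0) & 0x0F0F0F0F;  // low nibble of each byte
        const uint32_t q1 = (q4 >> 4) & 0x0F0F0F0F;  // high nibble of each byte
        for (int k = 0; k < 4; ++k) {
            lo[k] = values[(q0 >> 8*k) & 0xFF];
            hi[k] = values[(q1 >> 8*k) & 0xFF];
        }
    }

    int main() {
        const int8_t values[16] = {-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7}; // toy table
        int8_t lo[4], hi[4];
        table_16_pair(0x73B1A2F4u, values, lo, hi);
        // Lowest byte is 0xF4: low nibble 0x4 -> values[4]  = -4,
        //                      high nibble 0xF -> values[15] = 7.
        assert(lo[0] == -4 && hi[0] == 7);
        return 0;
    }

The lo values correspond to v.x and the hi values to v.y in the kernel, which is why load_tiles_iq4_ks stores v.x at k0 + 0 and v.y at k0 + 4.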