summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-11-21 07:12:11 +0100
committerGitHub <noreply@github.com>2024-11-21 07:12:11 +0100
commit4d2fbde0cbbfc98200b59ed6fe5b32628a70c055 (patch)
treec160e53e2f6fa70a5d4de141dab7e9e6ecd46059
parent52874c5d21819bd63cc4c500f2fb1be435d16b5e (diff)
MMQ for Q6_0 (#115)
* MMQ for Q6_0 * Add Q6_0 MMQ to template generator --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml-cuda/dmmv.cu2
-rw-r--r--ggml/src/ggml-cuda/mmq.cu4
-rw-r--r--ggml/src/ggml-cuda/mmq.cuh76
-rwxr-xr-xggml/src/ggml-cuda/template-instances/generate_cu_files.py2
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.cu5
5 files changed, 87 insertions, 2 deletions
diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 96a5adef..12738240 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -621,7 +621,7 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
src1_dfloat = src1_dfloat_a.alloc(ne00);
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
- to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
+ to_fp16_cuda(src1_ddf_i, src1_dfloat, 1, ne00, stream);
}
#else
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 09d3e9c7..f9fc2438 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -39,6 +39,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_Q5_1:
mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
break;
+ case GGML_TYPE_Q6_0:
+ mul_mat_q_case<GGML_TYPE_Q6_0>(ctx, args, stream);
+ break;
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
@@ -103,6 +106,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index e8a95744..416b4336 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -53,6 +53,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q5_1:
return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q6_0:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q8_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q2_K:
@@ -155,6 +157,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
+ type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
@@ -189,6 +192,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
+ type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -556,6 +560,69 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI6_0;
+ const int kqsx = threadIdx.x % QI6_0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx;
+
+ const int ql = get_int_b2(bxi->qs, kqsx);
+ const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2);
+
+ int qs0 = ((ql >> 0) & 0x0F0F0F0F) | ((qh << 4) & 0x30303030);
+ int qs1 = ((ql >> 4) & 0x0F0F0F0F) | ((qh << 2) & 0x30303030);
+ qs0 = __vsubss4(qs0, 0x20202020); // subtract 32
+ qs1 = __vsubss4(qs1, 0x20202020); // subtract 32
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + 0] = qs0;
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI6_0;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_0) {
+ int i = i0 + threadIdx.y * QI6_0 + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI6_0) + i/QI6_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@@ -2380,6 +2447,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> {
+ static constexpr int vdr = VDR_Q6_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_0<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
@@ -2910,6 +2985,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 4f7489d5..3037aa96 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -24,7 +24,7 @@ TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
- "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
+ "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_Q6_0"
]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.cu
new file mode 100644
index 00000000..8a728e6c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q6_0);