summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-10-22 11:34:49 +0200
committerGitHub <noreply@github.com>2024-10-22 11:34:49 +0200
commit462c6cd7b1b03843ab782e36c75da9bfea657c14 (patch)
treeb422b698b8373f7a8b29bb7b114a11bb9ba87dce
parentdbf951df1594a3dec36eca9ab81a0f7ba81b11cd (diff)
Enable q6_0 for flash attention (#101)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--Makefile2
-rw-r--r--ggml/src/CMakeLists.txt4
-rw-r--r--ggml/src/ggml-cuda/fattn-common.cuh83
-rw-r--r--ggml/src/ggml-cuda/fattn.cu15
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu5
-rwxr-xr-xggml/src/ggml-cuda/template-instances/generate_cu_files.py2
11 files changed, 120 insertions, 16 deletions
diff --git a/Makefile b/Makefile
index ae636f50..a7a53b1a 100644
--- a/Makefile
+++ b/Makefile
@@ -600,6 +600,8 @@ else
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu))
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*:iq4_nl-iq4_nl.cu))
+ OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*:q6_0-q5_0.cu))
+ OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*:q8_0-q6_0.cu))
endif # GGML_CUDA_FA_ALL_QUANTS
ifdef GGML_CUDA
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index eb6d457c..da1746c8 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -332,6 +332,10 @@ if (GGML_CUDA)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*iq4_nl-iq4_nl.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q6_0-q5_0.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q6_0.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif()
list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 1984c838..0a664dbd 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -277,6 +277,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
return sum;
}
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q6_0 * K_q6_0 = (const block_q6_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ T sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI6_0; // 0...3
+ const int shift = k_KQ & (QI8_1/2);
+
+ const int vh = (get_int_b2(K_q6_0[ib].qh, iqs4%2) >> (4*(iqs4/2) + shift/2)) & 0x03030303;
+ const int vl = (get_int_b2(K_q6_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int v = vl | (vh << 4);
+
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+ const half2 sum2 = __half2half2(K_q6_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+ sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(4.0f)) /* *32/QI8_1 == 4 */;
+ } else
+#endif // FP16_AVAILABLE
+ {
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+ sum += (T) (__half2float(K_q6_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (32/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+ }
+ }
+
+ return sum;
+}
+
template <typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -511,6 +554,30 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
}
template <typename T>
+static __device__ __forceinline__ T dequantize_1_q6_0(const void * __restrict__ vx, const int64_t i) {
+ const block_q6_0 * x = (const block_q6_0 *) vx;
+
+ const int64_t ib = i / QK6_0;
+ const int idq = i % QK6_0;
+ const int iqs = i % (QK6_0/2);
+ const int shift = idq / (QK6_0/2);
+ //const int shift = (i % QK6_0) / (QK6_0/2);
+
+ const T d = x[ib].d;
+ const int ql = x[ib].qs[iqs] >> 4*shift;
+ const int qh = x[ib].qh[idq%(QK6_0/4)] >> (4*((idq/(QK6_0/4))%2) + 2*shift);
+ const int q = ((ql & 0x0f) | ((qh & 0x03) << 4)) - 32;
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ return ((half) d)*((half) q);
+ }
+#endif // FP16_AVAILABLE
+
+ return ((float) d)*((float) q);
+}
+
+template <typename T>
static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
const block_q8_0 * x = (const block_q8_0 *) vx;
@@ -543,6 +610,7 @@ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+ type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, D> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
nullptr;
@@ -555,6 +623,7 @@ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+ type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, D> :
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
nullptr;
@@ -565,6 +634,7 @@ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+ type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0<half> :
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
@@ -576,6 +646,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+ type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0<float> :
type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
@@ -635,11 +706,13 @@ static void on_no_fattn_vec_case(const int D) {
} else if (D == 128) {
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
fprintf(stderr, "Supported combinations:\n");
- fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
- fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.50 BPV\n");
- fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.50 BPV\n");
- fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
- fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
+ fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n");
+ fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.5 BPV\n");
+ fprintf(stderr, " - K == q6_0, V == q5_0, 6.0 BPV\n");
+ fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.5 BPV\n");
+ fprintf(stderr, " - K == q8_0, V == q6_0, 7.5 BPV\n");
+ fprintf(stderr, " - K == q8_0, V == q8_0, 8.5 BPV\n");
+ fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n");
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
GGML_ABORT("fatal error");
} else {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index c5540161..1dfb24f9 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -208,13 +208,11 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
- //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
#else
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -224,13 +222,10 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
- //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
- //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0]);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu
new file mode 100644
index 00000000..d1ecb548
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu
new file mode 100644
index 00000000..e605e7a6
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu
new file mode 100644
index 00000000..80539daf
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu
new file mode 100644
index 00000000..78eca1c2
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu
new file mode 100644
index 00000000..fa16ddbc
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu
new file mode 100644
index 00000000..d25d482b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 1186112e..4f7489d5 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -3,7 +3,7 @@
from glob import glob
import os
-TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_IQ4_NL", "GGML_TYPE_F16"]
+TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_IQ4_NL", "GGML_TYPE_Q6_0", "GGML_TYPE_F16"]
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.