-rw-r--r--   ggml/src/ggml-cuda.cu                                                            |  5
-rw-r--r--   ggml/src/ggml-cuda/fattn.cu                                                      |  4
-rw-r--r--   ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu  |  5
-rw-r--r--   ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu  |  5

4 files changed, 19 insertions(+), 0 deletions(-)
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 63e16d52..1f62b882 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3578,6 +3578,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op)
if (op->src[0]->ne[0] == 128) {
return true;
}
+ if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 &&
+ (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) &&
+ (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) {
+ return true;
+ }
if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) {
return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
(op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
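
Note: the new branch above makes ggml_backend_cuda_supports_op report FlashAttention as supported for head size 256 whenever the K (src[1]) and V (src[2]) tensors are each stored as F16 or Q8_0. A minimal sketch of that predicate as a standalone helper (hypothetical name, assuming only ggml.h; not part of this commit):

#include "ggml.h"

// Sketch of the condition added above: K and V both have head size 256 and
// each is stored as F16 or Q8_0. The helper name is made up for illustration.
static bool cuda_fattn_vec_supports_hs256(const struct ggml_tensor * K,
                                          const struct ggml_tensor * V) {
    const bool k_ok = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
    const bool v_ok = V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q8_0;
    return K->ne[0] == 256 && V->ne[0] == 256 && k_ok && v_ok;
}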
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 8c3f4000..55116058 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -248,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
@@ -265,6 +266,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
@@ -347,6 +349,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
@@ -358,6 +361,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
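
Note: the FATTN_VEC_F16_CASE / FATTN_VEC_F32_CASE macros used above select and launch a templated kernel when the head size and the K/V cache types match, so the added lines simply extend the dispatch chain with the (256, Q8_0, Q8_0) combination. A rough sketch of how such a dispatch macro could look (an assumption paraphrased from the surrounding code, not part of this commit; Q, K, V, ctx and dst are taken from the enclosing dispatcher, and the *_case launcher is the templated entry point declared in the fattn headers):

// Assumed shape of the dispatch macro: if the head size and K/V types match,
// launch the matching template instance and return from the dispatcher.
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }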
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu
new file mode 100644
index 00000000..f257f5d8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu
new file mode 100644
index 00000000..a0f03f49
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
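
Note: the two new files under template-instances/ follow the pattern of the other autogenerated instance files: each one explicitly instantiates the vector FlashAttention launcher for head size 256 with Q8_0 K and Q8_0 V in its own translation unit, which spreads CUDA compile time across files and keeps the dispatcher in fattn.cu free of heavy template code. A hedged sketch of what DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) presumably expands to, assuming the launcher signature visible in the hunk headers above (not part of this commit):

// Assumed expansion: an explicit instantiation of the templated launcher
// declared in fattn-vec-f16.cuh for the (256, Q8_0, Q8_0) combination.
template void ggml_cuda_flash_attn_ext_vec_f16_case
    <256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);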