summary | refs | log | tree | commit | diff
path: root/ggml-cuda/template-instances/generate_cu_files.py
diff options
context:
space:
mode:
author: Johannes Gäßler <johannesg@5d6.de>, 2024-06-05 16:53:00 +0200
committer: GitHub <noreply@github.com>, 2024-06-05 16:53:00 +0200
commit: 7d1a378b8fb266782d9248538a661405aad80768 (patch)
tree: 7ce459a4c5a85e75f75825772124aedc3bb54b7f /ggml-cuda/template-instances/generate_cu_files.py
parent: 2b3389677a833cee0880226533a1768b1a9508d2 (diff)
CUDA: refactor mmq, dmmv, mmvq (#7716)
* CUDA: refactor mmq, dmmv, mmvq
* fix out-of-bounds write
* struct for qk, qr, qi
* fix cmake build
* mmq_type_traits
Diffstat (limited to 'ggml-cuda/template-instances/generate_cu_files.py')
-rwxr-xr-x ggml-cuda/template-instances/generate_cu_files.py | 16
1 files changed, 16 insertions, 0 deletions
diff --git a/ggml-cuda/template-instances/generate_cu_files.py b/ggml-cuda/template-instances/generate_cu_files.py
index ee5b460e..ea58d096 100755
--- a/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml-cuda/template-instances/generate_cu_files.py
@@ -20,6 +20,18 @@ SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
+TYPES_MMQ = [
+ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
+ "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K"
+]
+
+SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE({type});
+"""
+
def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -57,3 +69,7 @@ for kq_acc_t in ["half", "float"]:
if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
continue
f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
+
+for type in TYPES_MMQ:
+ with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
+ f.write(SOURCE_MMQ.format(type=type))