summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-10-21 12:16:54 +0200
committerGitHub <noreply@github.com>2024-10-21 12:16:54 +0200
commitdbf951df1594a3dec36eca9ab81a0f7ba81b11cd (patch)
treea1692718a4602927fe34dc313841aec582e7fe58
parentf2d315b46f7aacc7df4b86bd8acba387b30e11ca (diff)
Enable IQ4_NL for KV-cache in token generation using Flash Attention (#99)
* Enable IQ4_NL for V-cache in token generation * We don't need these * Update printour of allowed quantized KV-cache combinations * Add IQ4_NL + IQ4_NL to FA This is a better alternative than Q4_0 + Q4_0 for the VRAM poor. * Remove file added by mistake * Fix typo, which is not really a bug --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--Makefile6
-rw-r--r--ggml/src/CMakeLists.txt4
-rw-r--r--ggml/src/ggml-cuda/fattn-common.cuh132
-rw-r--r--ggml/src/ggml-cuda/fattn-vec-f16.cuh7
-rw-r--r--ggml/src/ggml-cuda/fattn.cu20
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-iq4_nl.cu5
-rwxr-xr-xggml/src/ggml-cuda/template-instances/generate_cu_files.py2
22 files changed, 214 insertions, 37 deletions
diff --git a/Makefile b/Makefile
index 658b6e35..ae636f50 100644
--- a/Makefile
+++ b/Makefile
@@ -246,11 +246,11 @@ endif
# Compile flags
#
-# keep standard at C11 and C++11
+# keep standard at C11 and C++17
MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon
MK_CFLAGS = -std=c11 -fPIC
MK_CXXFLAGS = -std=c++17 -fPIC
-MK_NVCCFLAGS = -std=c++11
+MK_NVCCFLAGS = -std=c++17
ifdef LLAMA_NO_CCACHE
GGML_NO_CCACHE := 1
@@ -598,6 +598,8 @@ else
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
+ OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu))
+ OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*:iq4_nl-iq4_nl.cu))
endif # GGML_CUDA_FA_ALL_QUANTS
ifdef GGML_CUDA
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 600acf91..eb6d457c 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -328,6 +328,10 @@ if (GGML_CUDA)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*iq4_nl-iq4_nl.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif()
list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 0bcd1ff7..1984c838 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -136,6 +136,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
return sum;
}
+static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) {
+ const uint8_t * q0_8 = (const uint8_t *) &q4;
+ const char4 val0_8 = make_char4(kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
+ return *((const int *) &val0_8);
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_iq4_nl * K_iq4_nl = (const block_iq4_nl *) K_c;
+ GGML_UNUSED(Q_v);
+
+ T sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI4_NL;
+ const int shift = k_KQ & (QI8_1/2);
+
+ const int v = get_one_int_from_table_16((get_int_b2(K_iq4_nl[ib].qs, iqs4) >> shift) & 0x0F0F0F0F);
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
+ sum += (T) (((half)sumi) * K_iq4_nl[ib].d * Q_ds[k_KQ_0/WARP_SIZE].x);
+ } else
+#endif // FP16_AVAILABLE
+ {
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+ sum += (T) ((float)sumi * __half2float(K_iq4_nl[ib].d) * Q_ds[k_KQ_0/WARP_SIZE].x);
+ }
+ }
+
+ return sum;
+}
+
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -378,6 +421,25 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
}
template <typename T>
+static __device__ __forceinline__ T dequantize_1_iq4_nl(const void * __restrict__ vx, const int64_t i) {
+ const block_iq4_nl * x = (const block_iq4_nl *) vx;
+
+ const int64_t ib = i / QK4_NL;
+ const int iqs = i % (QK4_NL/2);
+ const int shift = (i % QK4_NL) / (QK4_NL/2);
+
+#ifdef FP16_AVAILABLE
+ if constexpr (std::is_same<T, half>::value) {
+ return x[ib].d * ((half) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+ } else {
+ return (float)x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+ }
+#endif
+ T result = (float)x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]);
+ return result;
+}
+
+template <typename T>
static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -476,44 +538,48 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
template <int D>
constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
- nullptr;
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+ type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+ nullptr;
}
template <int D>
constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
- nullptr;
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+ type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+ nullptr;
}
constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
- return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
- type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
- type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
- type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
- type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
- type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
- nullptr;
+ return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
+ type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
+ type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
+ type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+ type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
+ type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
+ type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
+ nullptr;
}
constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
- return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
- type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
- type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
- type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
- type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
- type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
- nullptr;
+ return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
+ type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
+ type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
+ type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+ type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
+ type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
+ type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
+ nullptr;
}
template<int D, int parallel_blocks> // D == head size
@@ -569,10 +635,12 @@ static void on_no_fattn_vec_case(const int D) {
} else if (D == 128) {
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
fprintf(stderr, "Supported combinations:\n");
- fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
- fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
- fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
- fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+ fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
+ fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.50 BPV\n");
+ fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.50 BPV\n");
+ fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
+ fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
+ fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
GGML_ABORT("fatal error");
} else {
fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index cf628dd5..7f14e78b 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -392,6 +392,13 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
+//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
+
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index f87f33b3..c5540161 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -152,7 +152,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
} \
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * Q = dst->src[1];
+ ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];
ggml_tensor * V = dst->src[2];
@@ -207,6 +207,14 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+
+ //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
#else
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -215,6 +223,14 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+
+ //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+ //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0]);
@@ -227,7 +243,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
} \
static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * Q = dst->src[1];
+ ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];
ggml_tensor * V = dst->src[2];
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-iq4_nl.cu
new file mode 100644
index 00000000..34bbc716
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu
new file mode 100644
index 00000000..672a39d0
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-iq4_nl.cu
new file mode 100644
index 00000000..ad3fb05c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-iq4_nl.cu
new file mode 100644
index 00000000..b5f60a62
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-iq4_nl.cu
new file mode 100644
index 00000000..19407254
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-iq4_nl.cu
new file mode 100644
index 00000000..b2269714
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-iq4_nl.cu
new file mode 100644
index 00000000..68345ecb
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-iq4_nl.cu
new file mode 100644
index 00000000..a13ad97e
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-iq4_nl.cu
new file mode 100644
index 00000000..afd75bee
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu
new file mode 100644
index 00000000..286c9e20
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-iq4_nl.cu
new file mode 100644
index 00000000..49c69d6b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-iq4_nl.cu
new file mode 100644
index 00000000..af9bd9ed
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-iq4_nl.cu
new file mode 100644
index 00000000..a5eb1950
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-iq4_nl.cu
new file mode 100644
index 00000000..30a19366
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-iq4_nl.cu
new file mode 100644
index 00000000..59c402fb
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-iq4_nl.cu
new file mode 100644
index 00000000..932a16b7
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index d7874e6e..1186112e 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -3,7 +3,7 @@
from glob import glob
import os
-TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_F16"]
+TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_IQ4_NL", "GGML_TYPE_F16"]
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.