summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda/fattn-common.cuh
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda/fattn-common.cuh')
-rw-r--r--ggml/src/ggml-cuda/fattn-common.cuh108
1 files changed, 58 insertions, 50 deletions
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 0a664dbd..a46f03e5 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -52,7 +52,7 @@ typedef half (*vec_dot_KQ_f16_t)(
typedef float (*vec_dot_KQ_f32_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
-template<typename T, int D>
+template<typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -92,7 +92,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
return sum;
}
-template<typename T, int D>
+template<typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -142,7 +142,7 @@ static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4)
return *((const int *) &val0_8);
}
-template<typename T, int D>
+template<typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -152,7 +152,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -179,7 +179,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
return sum;
}
-template<typename T, int D>
+template<typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -189,7 +189,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -226,7 +226,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
return sum;
}
-template<typename T, int D>
+template<typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -277,7 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
return sum;
}
-template<typename T, int D>
+template<typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -287,7 +287,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
@@ -320,7 +320,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
return sum;
}
-template <typename T, int D>
+template <typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -330,7 +330,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_0;
@@ -353,7 +353,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
return sum;
}
-template <typename T, int D>
+template <typename T, int Dk>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
@@ -368,7 +368,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
half2 sum2 = make_half2(0.0f, 0.0f);
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
@@ -384,7 +384,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
float sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
@@ -603,29 +603,29 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
return x[i];
}
-template <int D>
+template <int Dk>
constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
- type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
- type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, D> :
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, Dk> :
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, Dk> :
+ type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, Dk> :
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, Dk> :
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, Dk> :
+ type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, Dk> :
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, Dk> :
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, Dk> :
nullptr;
}
-template <int D>
+template <int Dk>
constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
- type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
- type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, D> :
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, Dk> :
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, Dk> :
+ type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, Dk> :
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, Dk> :
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, Dk> :
+ type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, Dk> :
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, Dk> :
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, Dk> :
nullptr;
}
@@ -653,20 +653,20 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}
-template<int D, int parallel_blocks> // D == head size
+template<int Dv, int parallel_blocks> // Dv == V head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(D, 1)
+__launch_bounds__(Dv, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst) {
- VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
- VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
- dst += D * gridDim.y*blockIdx.x;
+ VKQ_parts += parallel_blocks*Dv * gridDim.y*blockIdx.x;
+ VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
+ dst += Dv * gridDim.y*blockIdx.x;
const int tid = threadIdx.x;
- __builtin_assume(tid < D);
+ __builtin_assume(tid < Dv);
__shared__ float2 meta[parallel_blocks];
if (tid < 2*parallel_blocks) {
@@ -690,20 +690,20 @@ static __global__ void flash_attn_combine_results(
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
- VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*Dv + blockIdx.y*Dv + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
}
- dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+ dst[blockIdx.y*Dv + tid] = VKQ_numerator / VKQ_denominator;
}
-static void on_no_fattn_vec_case(const int D) {
- if (D == 64) {
+static void on_no_fattn_vec_case(const int Dk, const int Dv) {
+ if (Dk == 64 && Dv == 64) {
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
fprintf(stderr, "By default only f16 KV cache is supported.\n");
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
GGML_ABORT("fatal error");
- } else if (D == 128) {
+ } else if (Dk == 128 && Dv == 128) {
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
fprintf(stderr, "Supported combinations:\n");
fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n");
@@ -715,14 +715,22 @@ static void on_no_fattn_vec_case(const int D) {
fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n");
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
GGML_ABORT("fatal error");
+ }
+ else if (Dk == 192 && Dv == 128) {
+ fprintf(stderr, "Unsupported KV type combination for head_sizes 192 / 128\n");
+ // TODO: add what is supported
+ }
+ else if (Dk == 576 && Dv == 512) {
+ fprintf(stderr, "Unsupported KV type combination for head_sizes 576 / 512\n");
+ // TODO: add what is supported
} else {
- fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+ fprintf(stderr, "Unsupported KV type combination for head_sizes %d, %d.\n", Dk, Dv);
fprintf(stderr, "Only f16 is supported.\n");
GGML_ABORT("fatal error");
}
}
-template <int D, int parallel_blocks>
+template <int Dk, int Dv, int parallel_blocks>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
@@ -838,11 +846,11 @@ void launch_fattn(
return;
}
- const dim3 block_dim_combine(D, 1, 1);
+ const dim3 block_dim_combine(Dv, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
- flash_attn_combine_results<D, parallel_blocks>
+ flash_attn_combine_results<Dv, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());