author    Johannes Gäßler <johannesg@5d6.de>    2024-06-01 15:47:04 +0200
committer GitHub <noreply@github.com>           2024-06-01 15:47:04 +0200
commit    750f60c03e4d3f53fa51910551ce87a3d508d2d7 (patch)
tree      6fdbcc6e7824b1ebd7586de57fff3c76636a86da /ggml-cuda/fattn-common.cuh
parent    9b596417af11c9ac44fcae0fcfbc6f3665089083 (diff)
CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (#7681)
Diffstat (limited to 'ggml-cuda/fattn-common.cuh')
-rw-r--r--  ggml-cuda/fattn-common.cuh  74
1 file changed, 59 insertions, 15 deletions
diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh
index 4bf03a49..c00f8606 100644
--- a/ggml-cuda/fattn-common.cuh
+++ b/ggml-cuda/fattn-common.cuh
@@ -1,6 +1,7 @@
#pragma once
#include "common.cuh"
+#include "convert.cuh"
#include "vecdotq.cuh"
#include <cstdint>
@@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)(
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
GGML_UNUSED(Q_v);
@@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
GGML_UNUSED(Q_v);
@@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
GGML_UNUSED(Q_v);
@@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
GGML_UNUSED(Q_v);
@@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
template <typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
GGML_UNUSED(Q_v);
@@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
template <typename T, int D>
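
Note on the guard change: the five vec_dot_fattn_vec_KQ_* hunks above (q4_0 through q8_0) all flip the same strict comparison to ">=". MIN_CC_DP4A names the first compute capability with the __dp4a instruction (6.1, i.e. 610, per ggml's common.cuh), and Pascal consumer GPUs report exactly that value, so "__CUDA_ARCH__ > MIN_CC_DP4A" compiled Pascal into the NO_DEVICE_CODE stub instead of the DP4A path. A minimal self-contained sketch of the corrected pattern (the helper name dot4_i8 is illustrative, and a scalar fallback stands in for the file's NO_DEVICE_CODE stub):

    #include <cstdint>

    #define MIN_CC_DP4A 610 // first CC with __dp4a, as defined in ggml's common.cuh

    static __device__ __forceinline__ int dot4_i8(const int a, const int b, int acc) {
    #if __CUDA_ARCH__ >= MIN_CC_DP4A
        // ">=" admits Pascal (CC 6.1 == 610); the old ">" excluded exactly it.
        return __dp4a(a, b, acc); // 4-way int8 dot product with accumulate
    #else
        // Scalar fallback for pre-Pascal devices.
        const int8_t * a8 = (const int8_t *) &a;
        const int8_t * b8 = (const int8_t *) &b;
    #pragma unroll
        for (int i = 0; i < 4; ++i) {
            acc += a8[i]*b8[i];
        }
        return acc;
    #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
    }
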
@@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) {
}
template <int D, int parallel_blocks>
-void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+void launch_fattn(
+ ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+ const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
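
The new need_f16_K/need_f16_V flags let each flash-attention kernel declare whether it can consume quantized K/V directly or needs the launcher to dequantize the cache to FP16 first. A hypothetical call site (the kernel names fattn_kernel_f16 and fattn_kernel_vec are illustrative, not from this patch):

    // FP16-only kernel (e.g. a tensor-core tile kernel): request conversion.
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel_f16,
            nwarps, cols_per_block, /*need_f16_K=*/true, /*need_f16_V=*/true);

    // Kernel with quantized vec-dot support: read the cache as stored.
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel_vec,
            nwarps, cols_per_block, /*need_f16_K=*/false, /*need_f16_V=*/false);
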
@@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
+ ggml_cuda_pool_alloc<half> K_f16(pool);
+ ggml_cuda_pool_alloc<half> V_f16(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+ char * K_data = (char *) K->data;
+ size_t nb11 = K->nb[1];
+ size_t nb12 = K->nb[2];
+ size_t nb13 = K->nb[3];
+
+ char * V_data = (char *) V->data;
+ size_t nb21 = V->nb[1];
+ size_t nb22 = V->nb[2];
+ size_t nb23 = V->nb[3];
+
+ if (need_f16_K && K->type != GGML_TYPE_F16) {
+ K_f16.alloc(ggml_nelements(K));
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+ K_data = (char *) K_f16.ptr;
+
+ const size_t bs = ggml_blck_size(K->type);
+ const size_t ts = ggml_type_size(K->type);
+
+ nb11 = nb11*bs*sizeof(half)/ts;
+ nb12 = nb12*bs*sizeof(half)/ts;
+ nb13 = nb13*bs*sizeof(half)/ts;
+ }
+
+ if (need_f16_V && V->type != GGML_TYPE_F16) {
+ V_f16.alloc(ggml_nelements(V));
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+ V_data = (char *) V_f16.ptr;
+
+ const size_t bs = ggml_blck_size(V->type);
+ const size_t ts = ggml_type_size(V->type);
+
+ nb21 = nb21*bs*sizeof(half)/ts;
+ nb22 = nb22*bs*sizeof(half)/ts;
+ nb23 = nb23*bs*sizeof(half)/ts;
+ }
+
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
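
The hunk above dequantizes K/V with the to_fp16 converters pulled in via the new convert.cuh include, then rescales the byte strides to match the converted buffer: each block of bs quantized values (ts bytes) becomes bs halves (bs*sizeof(half) bytes), so every stride is multiplied by bs*sizeof(half)/ts, with the multiplication done first so the integer division stays exact. A worked check, assuming ggml's Q8_0 layout (bs = 32 values per ts = 34-byte block: one half scale plus 32 int8 quants) and a row of ne0 = 128 elements:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const std::size_t bs = 32, ts = 34, ne0 = 128;   // Q8_0 block geometry
        const std::size_t nb11_q   = ne0/bs*ts;          // 136-byte quantized row stride
        const std::size_t nb11_f16 = nb11_q*bs*2/ts;     // 2 == sizeof(half)
        std::printf("%zu -> %zu\n", nb11_q, nb11_f16);   // 136 -> 256 == ne0*sizeof(half)
        return 0;
    }
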
@@ -667,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
(const char *) Q->data,
- (const char *) K->data,
- (const char *) V->data,
+ K_data,
+ V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
@@ -676,8 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
- K->nb[1], K->nb[2], K->nb[3],
- V->nb[1], V->nb[2], V->nb[3],
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
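
Taken together, the last two hunks hand the kernel the possibly-converted pointers (K_data, V_data) and the matching strides (nb11..nb23) instead of K/V's raw fields. Per the commit title, this is what lets batches larger than 8 run against a quantized KV cache: small batches keep the quantized vec-dot kernels, larger ones fall back to FP16 kernels, which now works because the launcher converts K/V up front. A hypothetical dispatcher sketching the idea (the kernel names and the exact dispatch condition are illustrative, not from this file):

    if (Q->ne[1] <= 8) {
        // Small batch: vec kernels consume the quantized cache directly.
        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel_vec,
                nwarps, cols_per_block, /*need_f16_K=*/false, /*need_f16_V=*/false);
    } else {
        // Batch > 8: FP16 kernels, so the launcher dequantizes K/V first.
        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel_f16,
                nwarps, cols_per_block, /*need_f16_K=*/true, /*need_f16_V=*/true);
    }
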