author    | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-27 17:40:59 +0300
committer | GitHub <noreply@github.com>                            | 2024-08-27 17:40:59 +0300
commit    | c7e99c88a2de7489ba2a1539b1a9025912010b70 (patch)
tree      | 9976409b1e8fac1fc7486f2c5da05a33b8e229b5 /ggml/src/ggml-cuda
parent    | bd99ed7d0afd2b12c0f5ff5c17b58486396dfe7e (diff)
Faster Gemma2 (#27)
* soft_cap_max: initial CPU version of fused softcap + soft_max
With this vanilla CPU implementation I am already getting a ~3% speedup
for Gemma-2-9b with a prompt of 8192 tokens (a scalar sketch of the fused
operation is given below, after the commit message).
* soft_cap_max: WIP - something is wrong with CUDA
* soft_cap_max: looks good on CPU and CUDA
* Add softcap to flash attention
Just CPU and CUDA for now (but, as we know, flash attention
on the CPU is useless in llama.cpp).
On CUDA this improves PP performance quite a bit, especially for
long contexts. E.g., for PP-16384, I now get 3777 t/s.
Without this change one cannot use FA at all (Gemma-2 requires the softcap),
and one gets 2300 t/s with the fused softcap+softmax, or 2000 t/s without
the fused softcap+softmax.
In comparison, mainline llama.cpp has PP-16384 = 1549 t/s before
PR-8542 (where Johannes Gaessler also added softcap to FA),
and PP-16384 = 3097 t/s after that PR.
* soft_cap_max: Metal
* Flash attention with softcap: Metal
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
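The fused softcap + soft_max operation from the first bullet applies Gemma-2-style logit soft-capping and the softmax in a single pass over each row of attention scores. The following is a minimal scalar sketch of its semantics, using the same parameterization as the CUDA kernel further down (scale, ALiBi slope, cap_params0/cap_params1); it is an illustration with made-up names, not the actual ggml CPU implementation:

    #include <math.h>

    // Illustrative only: fused softcap + softmax over one row of n attention scores.
    // c0/c1 play the role of cap_params0/cap_params1 in soft_max_f32 below; mask may be null.
    static void soft_cap_max_ref(const float * x, const float * mask, float * y, int n,
                                 float scale, float slope, float c0, float c1) {
        float max_val = -INFINITY;
        for (int i = 0; i < n; ++i) {
            const float v = scale*c1*tanhf(c0*x[i]) + (mask ? slope*mask[i] : 0.0f);
            y[i] = v;
            max_val = fmaxf(max_val, v);
        }
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            y[i] = expf(y[i] - max_val); // subtract the row max for numerical stability
            sum += y[i];
        }
        for (int i = 0; i < n; ++i) {
            y[i] /= sum;
        }
    }

Choosing c0 = 1/softcap and c1 = softcap (with the attention scale folded in appropriately) reproduces softcap * tanh(s / softcap), which is what Gemma-2 applies to its attention logits before the softmax.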
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--  ggml/src/ggml-cuda/fattn-common.cuh   |  9
-rw-r--r--  ggml/src/ggml-cuda/fattn-tile-f16.cu  | 46
-rw-r--r--  ggml/src/ggml-cuda/fattn-tile-f32.cu  | 40
-rw-r--r--  ggml/src/ggml-cuda/fattn-vec-f16.cuh  | 59
-rw-r--r--  ggml/src/ggml-cuda/fattn-vec-f32.cuh  | 56
-rw-r--r--  ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 38
-rw-r--r--  ggml/src/ggml-cuda/fattn.cu           |  4
-rw-r--r--  ggml/src/ggml-cuda/softmax.cu         | 67
-rw-r--r--  ggml/src/ggml-cuda/softmax.cuh        |  2
9 files changed, 259 insertions, 62 deletions
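A note on how the softcap reaches the flash-attention kernels in the diffs below (my reading of the patch, not an authoritative statement): launch_fattn divides the regular attention scale by softcap and passes softcap on to the kernels, which then apply softcap*tanhf(.) to the raw K*Q sums. Because the kernels fold the (already divided) scale into Q, the capped score becomes

    softcap * tanh( (scale/softcap) * (K*Q) ) = softcap * tanh( scale*(K*Q) / softcap ),

i.e. the Gemma-2 soft-capping of the scaled attention logits, while kernels instantiated with use_softcap == false keep the original behaviour.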
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 950fd93d..e4021764 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -21,6 +21,7 @@ typedef void (* fattn_kernel_t)(
        const float max_bias,
        const float m0,
        const float m1,
+       const float softcap,
        const uint32_t n_head_log2,
        const int ne00,
        const int ne01,
@@ -659,9 +660,15 @@ void launch_fattn(
    float scale = 1.0f;
    float max_bias = 0.0f;
+   float softcap = 0.0f;

    memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+   memcpy(&softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+   if (softcap != 0.0f) {
+       scale /= softcap;
+   }

    const uint32_t n_head = Q->ne[2];
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -675,7 +682,7 @@ void launch_fattn(
        V_data,
        mask ? ((const char *) mask->data) : nullptr,
        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-       scale, max_bias, m0, m1, n_head_log2,
+       scale, max_bias, m0, m1, softcap, n_head_log2,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 1b2fd500..d1bbf01f 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -4,7 +4,7 @@

 #define FATTN_KQ_STRIDE_TILE_F16 64

-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f16(
        const float max_bias,
        const float m0,
        const float m1,
+       const float softcap,
        const uint32_t n_head_log2,
        const int ne00,
        const int ne01,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
        const int ne2,
        const int ne3) {
 #ifdef FP16_AVAILABLE
+   // Skip unused kernel variants for faster compilation:
+   if (use_softcap && !(D == 128 || D == 256)) {
+       NO_DEVICE_CODE;
+       return;
+   }
+
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                const int j_KQ = j_KQ_0 + threadIdx.y;

-               half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+               half sum;
+               if (use_softcap) {
+                   const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                   sum = softcap * tanhf(tmp.x + tmp.y);
+               } else {
+                   sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+               }
                sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

                kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    switch (Q->ne[0]) {
        case 64: {
            constexpr int D = 64;
            constexpr int nwarps = 8;
-           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
        } break;
        case 128: {
            constexpr int D = 128;
            constexpr int nwarps = 8;
-           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
        } break;
        default: {
@@ -296,24 +309,39 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];

-   const int32_t precision = KQV->op_params[2];
+   const int32_t precision = KQV->op_params[3];
    GGML_ASSERT(precision == GGML_PREC_DEFAULT);

+   float softcap;
+   memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
    if (Q->ne[1] <= 16) {
        constexpr int cols_per_block = 16;
        constexpr int parallel_blocks = 4;
-       launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+       if (softcap == 0.0f) {
+           launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+       } else {
+           launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] <= 32) {
        constexpr int cols_per_block = 32;
        constexpr int parallel_blocks = 4;
-       launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+       if (softcap == 0.0f) {
+           launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+       } else {
+           launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+       }
        return;
    }

    constexpr int cols_per_block = 32;
    constexpr int parallel_blocks = 1;
-   launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+   if (softcap == 0.0f) {
+       launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+   } else {
+       launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+   }
 }
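A scalar view of the changed accumulation in flash_attn_tile_ext_f16 above: when use_softcap is set, the two half lanes of the K*Q accumulator are widened to fp32, summed, capped, and only the final value is narrowed back to half, presumably to keep tanh out of half precision. A hypothetical helper (not part of the patch) with the same arithmetic:

    #include <cuda_fp16.h>

    // Hypothetical helper, not in the patch: cap the lane sum of a half2 accumulator in fp32.
    __device__ __forceinline__ half capped_sum(half2 acc, float softcap) {
        const float2 tmp = __half22float2(acc);              // widen both lanes to fp32
        return __float2half(softcap * tanhf(tmp.x + tmp.y)); // cap in fp32, narrow once
    }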
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index f3e68dbf..25908d7a 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -4,7 +4,7 @@

 #define FATTN_KQ_STRIDE_TILE_F32 32

-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f32(
        const float max_bias,
        const float m0,
        const float m1,
+       const float softcap,
        const uint32_t n_head_log2,
        const int ne00,
        const int ne01,
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
        const int ne1,
        const int ne2,
        const int ne3) {
+   // Skip unused kernel variants for faster compilation:
+   if (use_softcap && !(D == 128 || D == 256)) {
+       NO_DEVICE_CODE;
+       return;
+   }
+
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                const int j_KQ = j_KQ_0 + threadIdx.y;

+               if (use_softcap) {
+                   sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+               }
+
                sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

                kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
    }
 }

-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    switch (Q->ne[0]) {
        case 64: {
            constexpr int D = 64;
            constexpr int nwarps = 8;
-           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
        } break;
        case 128: {
            constexpr int D = 128;
            constexpr int nwarps = 8;
-           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+           fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
        } break;
        default: {
@@ -292,21 +303,36 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

+   float softcap;
+   memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
    if (Q->ne[1] <= 16) {
        constexpr int cols_per_block = 16;
        constexpr int parallel_blocks = 4;
-       launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+       if (softcap == 0.0f) {
+           launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+       } else {
+           launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] <= 32) {
        constexpr int cols_per_block = 32;
        constexpr int parallel_blocks = 4;
-       launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+       if (softcap == 0.0f) {
+           launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+       } else {
+           launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+       }
        return;
    }

    constexpr int cols_per_block = 32;
    constexpr int parallel_blocks = 1;
-   launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+   if (softcap == 0.0f) {
+       launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+   } else {
+       launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+   }
 }
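All the host-side wrappers in this patch follow the same pattern: read softcap from op_params once, then pick a kernel instantiation with use_softcap as a compile-time template parameter, so the per-element if (use_softcap) costs nothing in the variant that does not need it, and unused head sizes are stubbed out with NO_DEVICE_CODE. A minimal, self-contained sketch of that pattern with hypothetical names:

    #include <math.h>

    // Minimal sketch of the compile-time dispatch used throughout this patch (hypothetical names).
    template <bool use_softcap>
    __global__ void my_kernel(const float * x, float * y, int n, float softcap) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= n) return;
        float v = x[i];
        if (use_softcap) {      // resolved at compile time; no branch in the hot path of the false variant
            v = softcap * tanhf(v);
        }
        y[i] = v;
    }

    static void launch_my_kernel(const float * x, float * y, int n, float softcap, cudaStream_t stream) {
        const int block = 256;
        const int grid  = (n + block - 1)/block;
        if (softcap == 0.0f) {
            my_kernel<false><<<grid, block, 0, stream>>>(x, y, n, softcap);
        } else {
            my_kernel<true><<<grid, block, 0, stream>>>(x, y, n, softcap);
        }
    }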
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 02a4ad07..cf628dd5 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"

-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
        const float max_bias,
        const float m0,
        const float m1,
+       const float softcap,
        const uint32_t n_head_log2,
        const int ne00,
        const int ne01,
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
        const int ne2,
        const int ne3) {
 #ifdef FP16_AVAILABLE
+   // Skip unused kernel variants for faster compilation:
+   if (use_softcap && !(D == 128 || D == 256)) {
+       NO_DEVICE_CODE;
+       return;
+   }
+
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
@@ -190,6 +197,9 @@ static __global__ void flash_attn_vec_ext_f16(
        for (int j = 0; j < ncols; ++j) {
            half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
            sum = warp_reduce_sum(sum);
+           if (use_softcap) {
+               sum = softcap*tanhf(sum);
+           }
            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

            if (ncols == 1) {
@@ -286,10 +296,10 @@ static __global__ void flash_attn_vec_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr int nwarps = D/WARP_SIZE;
-   fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+   fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
    constexpr bool need_f16_K = D != 128;
    constexpr bool need_f16_V = D != 128 && D != 64;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -297,48 +307,71 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,

 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-   ggml_tensor * KQV = dst;
-   ggml_tensor * Q = dst->src[0];
-   ggml_tensor * K = dst->src[1];
-   ggml_tensor * V = dst->src[2];
+   const ggml_tensor * KQV = dst;
+   const ggml_tensor * Q = dst->src[0];
+   const ggml_tensor * K = dst->src[1];
+   const ggml_tensor * V = dst->src[2];

-   const int32_t precision = KQV->op_params[2];
+   const int32_t precision = KQV->op_params[3];
    GGML_ASSERT(precision == GGML_PREC_DEFAULT);

    GGML_ASSERT(K->type == type_K);
    GGML_ASSERT(V->type == type_V);

+   float softcap;
+   memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
    if (Q->ne[1] == 1) {
        constexpr int cols_per_block = 1;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] == 2) {
        constexpr int cols_per_block = 2;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] <= 4) {
        constexpr int cols_per_block = 4;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] <= 8) {
        constexpr int cols_per_block = 8;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    constexpr int cols_per_block = 8;
    constexpr int parallel_blocks = 1;
-   ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+   if (softcap == 0.0f) {
+       ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+   } else {
+       ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+   }
 }

 #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 11a5e355..1aa88272 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"

-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
        const float max_bias,
        const float m0,
        const float m1,
+       const float softcap,
        const uint32_t n_head_log2,
        const int ne00,
        const int ne01,
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
        const int ne1,
        const int ne2,
        const int ne3) {
+   // Skip unused kernel variants for faster compilation:
+   if (use_softcap && !(D == 128 || D == 256)) {
+       NO_DEVICE_CODE;
+       return;
+   }
+
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
@@ -180,6 +187,9 @@ static __global__ void flash_attn_vec_ext_f32(
        for (int j = 0; j < ncols; ++j) {
            float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
            sum = warp_reduce_sum(sum);
+           if (use_softcap) {
+               sum = softcap*tanhf(sum);
+           }
            sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

            kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -267,10 +277,10 @@ static __global__ void flash_attn_vec_ext_f32(
    }
 }

-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr int nwarps = D/WARP_SIZE;
-   fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+   fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
    constexpr bool need_f16_K = D != 128;
    constexpr bool need_f16_V = D != 128 && D != 64;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -278,44 +288,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,

 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-   ggml_tensor * Q = dst->src[0];
-   ggml_tensor * K = dst->src[1];
-   ggml_tensor * V = dst->src[2];
+   const ggml_tensor * KQV = dst;
+   const ggml_tensor * Q = dst->src[0];
+   const ggml_tensor * K = dst->src[1];
+   const ggml_tensor * V = dst->src[2];

    GGML_ASSERT(K->type == type_K);
    GGML_ASSERT(V->type == type_V);

+   float softcap;
+   memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
    if (Q->ne[1] == 1) {
        constexpr int cols_per_block = 1;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] == 2) {
        constexpr int cols_per_block = 2;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] <= 4) {
        constexpr int cols_per_block = 4;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    if (Q->ne[1] <= 8) {
        constexpr int cols_per_block = 8;
        constexpr int parallel_blocks = 4;
-       ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+       if (softcap == 0.0f) {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+       } else {
+           ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+       }
        return;
    }

    constexpr int cols_per_block = 8;
    constexpr int parallel_blocks = 1;
-   ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+   if (softcap == 0.0f) {
+       ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+   } else {
+       ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+   }
 }

 #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index ae232224..efe78a2f 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -6,7 +6,7 @@
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -21,6 +21,7 @@ static __global__ void flash_attn_ext_f16(
        const float max_bias,
        const float m0,
        const float m1,
+       const float softcap,
        const uint32_t n_head_log2,
        const int ne00,
        const int ne01,
@@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
        const int ne2,
        const int ne3) {
 #ifdef FP16_MMA_AVAILABLE
+   // Skip unused kernel variants for faster compilation:
+   if (use_softcap && !(D == 128 || D == 256)) {
+       NO_DEVICE_CODE;
+       return;
+   }
+
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@@ -84,6 +91,7 @@ static __global__ void flash_attn_ext_f16(
    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
    const half slopeh = __float2half(slopef);
    const half2 slope2 = make_half2(slopef, slopef);
+   const half2 softcap_2 = make_half2(softcap, softcap);

    frag_b Q_b[D/16][ncols/frag_n];

@@ -194,6 +202,9 @@ static __global__ void flash_attn_ext_f16(
                const int k = k0 + threadIdx.x;

                KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+               if (use_softcap) {
+                   KQ_f_tmp[k0/WARP_SIZE] = softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+               }
            }

            float KQ_max_new = KQ_max_f[j0/nwarps];
@@ -237,6 +248,16 @@ static __global__ void flash_attn_ext_f16(
                const int k = k0 + threadIdx.x;

                KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+               if (use_softcap) {
+                   // There is no dedicated tangens hyperbolicus function for half2.
+                   // Yes, and the code below can produce NaNs on overflow
+                   KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                   KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                          /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                   KQ2_tmp[k0/WARP_SIZE] *= softcap_2;
+               }
            }

            half2 KQ_max_new = KQ_max_h2[j0/nwarps];
@@ -435,20 +456,29 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

+   float softcap;
+   memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
    if (4*blocks_num_pb1 < 2*nsm) {
        constexpr int parallel_blocks = 4;
-       fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+       fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+           flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+           flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
        return;
    }
    if (2*blocks_num_pb1 < 2*nsm) {
        constexpr int parallel_blocks = 2;
-       fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+       fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+           flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+           flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
        return;
    }
    constexpr int parallel_blocks = 1;
-   fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+   fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+       flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+       flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
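The wmma path has no half2 tanh, so the patch uses the identity tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) directly in half precision; as the added comment notes, h2exp(2x) overflows the fp16 range (max ~65504) once x exceeds roughly 5.5, and the quotient then becomes inf/inf = NaN. One possible NaN-free alternative (not part of the patch) is to widen to fp32 just for the tanh:

    #include <cuda_fp16.h>

    // Hypothetical alternative, not in the patch: tanh of both half2 lanes computed in fp32.
    __device__ __forceinline__ half2 tanh_half2_via_f32(half2 x) {
        const float2 xf = __half22float2(x);               // widen both lanes
        return __floats2half2_rn(tanhf(xf.x), tanhf(xf.y)); // tanhf saturates cleanly for large |x|
    }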
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 29f608b0..f87f33b3 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];

-   const int32_t precision = KQV->op_params[2];
+   const int32_t precision = KQV->op_params[3];

    if (precision != GGML_PREC_DEFAULT) {
        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    ggml_cuda_set_device(ctx.device);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-   const int32_t precision = KQV->op_params[2];
+   const int32_t precision = KQV->op_params[3];

    // On AMD the tile kernels perform poorly, use the vec kernel instead:
    if (cc >= CC_OFFSET_AMD) {
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index c24abae1..6f3056e6 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -12,7 +12,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 }

 template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

    const int tid = threadIdx.x;
@@ -44,7 +44,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
        const int64_t ix = (int64_t)rowx*ncols + col;
        const int64_t iy = (int64_t)rowy*ncols + col;

-       const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+       const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) :
+                                      scale*x[ix] + (mask ? slope*t2f32(mask[iy]) : 0.0f);

        vals[col] = val;
        max_val = max(max_val, val);
@@ -116,7 +117,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
 }

 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    const dim3 block_dims(nth, 1, 1);
@@ -134,36 +135,36 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
        switch (ncols_x) {
            case 32:
-               soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            case 64:
-               soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            case 128:
-               soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            case 256:
-               soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            case 512:
-               soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            case 1024:
-               soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            case 2048:
-               soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            case 4096:
-               soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
            default:
-               soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+               soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
                break;
        }
    } else {
        const size_t shmem_low = WARP_SIZE*sizeof(float);
-       soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+       soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
    }
 }

@@ -197,10 +198,46 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    if (use_f16) {
        const half * src1_dd = (const half *)src1_d;

-       soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+       soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
    } else {
        const float * src1_dd = (const float *)src1_d;

-       soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+       soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
+   }
+}
+
+void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+   const ggml_tensor * src0 = dst->src[0];
+   const ggml_tensor * src1 = dst->src[1];
+
+   const float * src0_d = (const float *)src0->data;
+   const void * src1_d = src1 ? (const void *)src1->data : nullptr;
+
+   float * dst_d = (float *)dst->data;
+   cudaStream_t stream = ctx.stream();
+
+   GGML_ASSERT(src0->type == GGML_TYPE_F32);
+   GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+   GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+   const int64_t ne00    = src0->ne[0];
+   const int64_t nrows_x = ggml_nrows(src0);
+   const int64_t nrows_y = src0->ne[1];
+
+   float params[4];
+   memcpy(params, dst->op_params, sizeof(params));
+
+   const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+   //printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16);
+
+   if (use_f16) {
+       const half * src1_dd = (const half *)src1_d;
+
+       soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
+   } else {
+       const float * src1_dd = (const float *)src1_d;
+
+       soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
    }
 }
diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh
index 4ef4ff86..49a83dfa 100644
--- a/ggml/src/ggml-cuda/softmax.cuh
+++ b/ggml/src/ggml-cuda/softmax.cuh
@@ -3,3 +3,5 @@
 #define CUDA_SOFT_MAX_BLOCK_SIZE 1024

 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
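For reference, this is how the four op_params of the fused operator are wired through ggml_cuda_op_soft_cap_max into soft_max_f32 above (taken directly from the code; the comments are mine):

    // dst->op_params[0] -> scale        same meaning as in the plain soft_max path
    // dst->op_params[1] -> max_bias     ALiBi slope base, as in ggml_cuda_op_soft_max
    // dst->op_params[2] -> cap_params0  multiplies x inside tanhf
    // dst->op_params[3] -> cap_params1  multiplies the tanhf result
    //
    // so with do_softcap == true each element becomes
    //   val = scale * cap_params1 * tanhf(cap_params0 * x) + (mask ? slope*mask : 0.0f)
    // before the usual max-subtracted softmax over the row.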