summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/fattn-common.cuh9
-rw-r--r--ggml/src/ggml-cuda/fattn-tile-f16.cu46
-rw-r--r--ggml/src/ggml-cuda/fattn-tile-f32.cu40
-rw-r--r--ggml/src/ggml-cuda/fattn-vec-f16.cuh59
-rw-r--r--ggml/src/ggml-cuda/fattn-vec-f32.cuh56
-rw-r--r--ggml/src/ggml-cuda/fattn-wmma-f16.cuh38
-rw-r--r--ggml/src/ggml-cuda/fattn.cu4
-rw-r--r--ggml/src/ggml-cuda/softmax.cu67
-rw-r--r--ggml/src/ggml-cuda/softmax.cuh2
9 files changed, 259 insertions, 62 deletions
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 950fd93d..e4021764 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -21,6 +21,7 @@ typedef void (* fattn_kernel_t)(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -659,9 +660,15 @@ void launch_fattn(
float scale = 1.0f;
float max_bias = 0.0f;
+ float softcap = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+ memcpy(&softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+ if (softcap != 0.0f) {
+ scale /= softcap;
+ }
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -675,7 +682,7 @@ void launch_fattn(
V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale, max_bias, m0, m1, n_head_log2,
+ scale, max_bias, m0, m1, softcap, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 1b2fd500..d1bbf01f 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F16 64
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f16(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
- half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ half sum;
+ if (use_softcap) {
+ const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ sum = softcap * tanhf(tmp.x + tmp.y);
+ } else {
+ sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ }
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
@@ -296,24 +309,39 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+ float softcap;
+ memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ }
}
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index f3e68dbf..25908d7a 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F32 32
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f32(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
+ if (use_softcap) {
+ sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ }
+
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
}
}
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
@@ -292,21 +303,36 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
+ float softcap;
+ memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
}
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 02a4ad07..cf628dd5 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -1,7 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
@@ -190,6 +197,9 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
+ if (use_softcap) {
+ sum = softcap*tanhf(sum);
+ }
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
if (ncols == 1) {
@@ -286,10 +296,10 @@ static __global__ void flash_attn_vec_ext_f16(
#endif // FP16_AVAILABLE
}
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -297,48 +307,71 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * KQV = dst;
- ggml_tensor * Q = dst->src[0];
- ggml_tensor * K = dst->src[1];
- ggml_tensor * V = dst->src[2];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
+ float softcap;
+ memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
}
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 11a5e355..1aa88272 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -1,7 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
@@ -180,6 +187,9 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
+ if (use_softcap) {
+ sum = softcap*tanhf(sum);
+ }
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -267,10 +277,10 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -278,44 +288,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * Q = dst->src[0];
- ggml_tensor * K = dst->src[1];
- ggml_tensor * V = dst->src[2];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
+ float softcap;
+ memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
}
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index ae232224..efe78a2f 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -6,7 +6,7 @@
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -21,6 +21,7 @@ static __global__ void flash_attn_ext_f16(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_MMA_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@@ -84,6 +91,7 @@ static __global__ void flash_attn_ext_f16(
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
+ const half2 softcap_2 = make_half2(softcap, softcap);
frag_b Q_b[D/16][ncols/frag_n];
@@ -194,6 +202,9 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+ if (use_softcap) {
+ KQ_f_tmp[k0/WARP_SIZE] = softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+ }
}
float KQ_max_new = KQ_max_f[j0/nwarps];
@@ -237,6 +248,16 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+ if (use_softcap) {
+ // There is no dedicated tangens hyperbolicus function for half2.
+ // Yes, and the code below can produce NaNs on overflow
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+ KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+ /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+ KQ2_tmp[k0/WARP_SIZE] *= softcap_2;
+ }
+
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
@@ -435,20 +456,29 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+ float softcap;
+ memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
constexpr int parallel_blocks = 1;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
}
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 29f608b0..f87f33b3 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
if (precision != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) {
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index c24abae1..6f3056e6 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -12,7 +12,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@@ -44,7 +44,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
- const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+ const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) :
+ scale*x[ix] + (mask ? slope*t2f32(mask[iy]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
@@ -116,7 +117,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
}
template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -134,36 +135,36 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 64:
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 128:
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 256:
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 512:
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 1024:
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 2048:
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 4096:
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
default:
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
}
}
@@ -197,10 +198,46 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
} else {
const float * src1_dd = (const float *)src1_d;
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
+ }
+}
+
+void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ const float * src0_d = (const float *)src0->data;
+ const void * src1_d = src1 ? (const void *)src1->data : nullptr;
+
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ float params[4];
+ memcpy(params, dst->op_params, sizeof(params));
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+ //printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16);
+
+ if (use_f16) {
+ const half * src1_dd = (const half *)src1_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
+ } else {
+ const float * src1_dd = (const float *)src1_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
}
diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh
index 4ef4ff86..49a83dfa 100644
--- a/ggml/src/ggml-cuda/softmax.cuh
+++ b/ggml/src/ggml-cuda/softmax.cuh
@@ -3,3 +3,5 @@
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);