-rw-r--r--  ggml/include/ggml.h                    23
-rw-r--r--  ggml/src/ggml-cuda.cu                   4
-rw-r--r--  ggml/src/ggml-cuda/fattn-common.cuh     9
-rw-r--r--  ggml/src/ggml-cuda/fattn-tile-f16.cu   46
-rw-r--r--  ggml/src/ggml-cuda/fattn-tile-f32.cu   40
-rw-r--r--  ggml/src/ggml-cuda/fattn-vec-f16.cuh   59
-rw-r--r--  ggml/src/ggml-cuda/fattn-vec-f32.cuh   56
-rw-r--r--  ggml/src/ggml-cuda/fattn-wmma-f16.cuh  38
-rw-r--r--  ggml/src/ggml-cuda/fattn.cu             4
-rw-r--r--  ggml/src/ggml-cuda/softmax.cu          67
-rw-r--r--  ggml/src/ggml-cuda/softmax.cuh          2
-rw-r--r--  ggml/src/ggml-metal.m                  88
-rw-r--r--  ggml/src/ggml-metal.metal             249
-rw-r--r--  ggml/src/ggml.c                       251
-rw-r--r--  src/llama.cpp                          71
-rw-r--r--  tests/test-backend-ops.cpp             22
16 files changed, 896 insertions(+), 133 deletions(-)
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 17d3cb1a..1a4a516c 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -515,6 +515,7 @@ extern "C" {
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
GGML_OP_SOFTCAP,
+ GGML_OP_SOFT_CAP_MAX,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
@@ -1237,6 +1238,25 @@ extern "C" {
float s_before,
float s_after);
+ GGML_API struct ggml_tensor * ggml_softcap_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias,
+ float s_before,
+ float s_after);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_softcap_max_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias,
+ float s_before,
+ float s_after);
+
// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
struct ggml_context * ctx,
@@ -1791,7 +1811,8 @@ extern "C" {
struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale,
- float max_bias);
+ float max_bias,
+ float softcap);
GGML_API void ggml_flash_attn_ext_set_prec(
struct ggml_tensor * a,
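
The new ggml_softcap_max op fuses tanh-based logit soft-capping with the masked, scaled softmax. A minimal scalar sketch of the intended row-wise semantics, in plain C (names are illustrative, not part of the ggml API):

    #include <math.h>

    // y_i = softmax_i( scale * s_after * tanh(s_before * x_i) + slope * mask_i )
    static void softcap_max_row_ref(int n, const float * x, const float * mask, float * y,
                                    float scale, float slope, float s_before, float s_after) {
        float max_val = -INFINITY;
        for (int i = 0; i < n; ++i) {
            y[i] = scale*s_after*tanhf(s_before*x[i]) + (mask ? slope*mask[i] : 0.0f);
            max_val = fmaxf(max_val, y[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            y[i] = expf(y[i] - max_val);
            sum += y[i];
        }
        const float inv_sum = 1.0f/sum;
        for (int i = 0; i < n; ++i) {
            y[i] *= inv_sum;
        }
    }
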
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 73ab0b73..056ca4a4 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2286,6 +2286,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SOFT_MAX:
ggml_cuda_op_soft_max(ctx, dst);
break;
+ case GGML_OP_SOFT_CAP_MAX:
+ ggml_cuda_op_soft_cap_max(ctx, dst);
+ break;
case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst);
break;
@@ -2876,6 +2879,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_CAP_MAX:
return true;
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 950fd93d..e4021764 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -21,6 +21,7 @@ typedef void (* fattn_kernel_t)(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -659,9 +660,15 @@ void launch_fattn(
float scale = 1.0f;
float max_bias = 0.0f;
+ float softcap = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+ memcpy(&softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+ if (softcap != 0.0f) {
+ scale /= softcap;
+ }
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -675,7 +682,7 @@ void launch_fattn(
V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale, max_bias, m0, m1, n_head_log2,
+ scale, max_bias, m0, m1, softcap, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
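
Dividing scale by softcap on the host lets the kernels apply the cap as softcap*tanh(s) to the already-scaled dot product, which reproduces the usual soft-capping formula. A scalar sketch of the equivalence (illustrative only, not part of the patch):

    #include <math.h>

    // With scale' = scale/softcap the kernel computes softcap * tanhf(qk * scale'),
    // i.e. the intended softcap * tanh(qk * scale / softcap); softcap == 0.0f disables the cap.
    static float capped_logit(float qk, float scale, float softcap) {
        const float scale_adj = softcap != 0.0f ? scale/softcap : scale;
        const float s = qk*scale_adj;
        return softcap != 0.0f ? softcap*tanhf(s) : s;
    }
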
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 1b2fd500..d1bbf01f 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F16 64
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f16(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
- half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ half sum;
+ if (use_softcap) {
+ const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ sum = softcap * tanhf(tmp.x + tmp.y);
+ } else {
+ sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ }
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
@@ -296,24 +309,39 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+ float softcap;
+ memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
}
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index f3e68dbf..25908d7a 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F32 32
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f32(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
+ if (use_softcap) {
+ sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ }
+
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
}
}
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
@@ -292,21 +303,36 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
+ float softcap;
+ memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (softcap == 0.0f) {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+ } else {
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+ }
}
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 02a4ad07..cf628dd5 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -1,7 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
@@ -190,6 +197,9 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
+ if (use_softcap) {
+ sum = softcap*tanhf(sum);
+ }
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
if (ncols == 1) {
@@ -286,10 +296,10 @@ static __global__ void flash_attn_vec_ext_f16(
#endif // FP16_AVAILABLE
}
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -297,48 +307,71 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * KQV = dst;
- ggml_tensor * Q = dst->src[0];
- ggml_tensor * K = dst->src[1];
- ggml_tensor * V = dst->src[2];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
+ float softcap;
+ memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
}
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 11a5e355..1aa88272 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -1,7 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
@@ -180,6 +187,9 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
+ if (use_softcap) {
+ sum = softcap*tanhf(sum);
+ }
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -267,10 +277,10 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -278,44 +288,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * Q = dst->src[0];
- ggml_tensor * K = dst->src[1];
- ggml_tensor * V = dst->src[2];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
+ float softcap;
+ memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (softcap == 0.0f) {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+ }
}
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index ae232224..efe78a2f 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -6,7 +6,7 @@
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -21,6 +21,7 @@ static __global__ void flash_attn_ext_f16(
const float max_bias,
const float m0,
const float m1,
+ const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
@@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_MMA_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@@ -84,6 +91,7 @@ static __global__ void flash_attn_ext_f16(
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
+ const half2 softcap_2 = make_half2(softcap, softcap);
frag_b Q_b[D/16][ncols/frag_n];
@@ -194,6 +202,9 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+ if (use_softcap) {
+ KQ_f_tmp[k0/WARP_SIZE] = softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+ }
}
float KQ_max_new = KQ_max_f[j0/nwarps];
@@ -237,6 +248,16 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+ if (use_softcap) {
+                // There is no dedicated hyperbolic tangent function for half2,
+                // so tanh is computed via exp; note that the expression below can overflow to NaN.
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+ KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+ /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+ KQ2_tmp[k0/WARP_SIZE] *= softcap_2;
+ }
+
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
@@ -435,20 +456,29 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+ float softcap;
+ memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
constexpr int parallel_blocks = 1;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
}
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 29f608b0..f87f33b3 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
if (precision != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) {
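
With softcap stored at op_params[2], the precision flag moves from op_params[2] to op_params[3]. A small helper sketch (assumed names, not in the patch) of how the CUDA code above unpacks the FLASH_ATTN_EXT parameters:

    #include <string.h>
    #include <stdint.h>
    #include "ggml.h"

    // op_params layout: [0] scale (f32), [1] max_bias (f32), [2] softcap (f32, new), [3] precision (i32)
    static void get_fattn_ext_params(const struct ggml_tensor * KQV,
                                     float * scale, float * max_bias, float * softcap, int32_t * prec) {
        memcpy(scale,    (const float *) KQV->op_params + 0, sizeof(float));
        memcpy(max_bias, (const float *) KQV->op_params + 1, sizeof(float));
        memcpy(softcap,  (const float *) KQV->op_params + 2, sizeof(float));
        *prec = KQV->op_params[3];
    }
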
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index c24abae1..6f3056e6 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -12,7 +12,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@@ -44,7 +44,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
- const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+ const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) :
+ scale*x[ix] + (mask ? slope*t2f32(mask[iy]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
@@ -116,7 +117,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
}
template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -134,36 +135,36 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 64:
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 128:
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 256:
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 512:
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 1024:
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 2048:
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 4096:
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
default:
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
}
}
@@ -197,10 +198,46 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
} else {
const float * src1_dd = (const float *)src1_d;
- soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
+ }
+}
+
+void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ const float * src0_d = (const float *)src0->data;
+ const void * src1_d = src1 ? (const void *)src1->data : nullptr;
+
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ float params[4];
+ memcpy(params, dst->op_params, sizeof(params));
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+ //printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16);
+
+ if (use_f16) {
+ const half * src1_dd = (const half *)src1_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
+ } else {
+ const float * src1_dd = (const float *)src1_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
}
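
GGML_OP_SOFT_CAP_MAX reuses the existing soft_max_f32 kernel via the do_softcap flag, with cap_params0 = s_before and cap_params1 = s_after. A sketch of the per-element value that enters the max/exp/normalize pipeline (illustrative):

    #include <math.h>

    // do_softcap == false: val = scale*x + slope*m
    // do_softcap == true : val = scale*s_after*tanhf(s_before*x) + slope*m
    static float soft_cap_max_input(float x, float m, float slope,
                                    float scale, float s_before, float s_after, int do_softcap) {
        return do_softcap ? scale*s_after*tanhf(s_before*x) + slope*m
                          : scale*x + slope*m;
    }
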
diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh
index 4ef4ff86..49a83dfa 100644
--- a/ggml/src/ggml-cuda/softmax.cuh
+++ b/ggml/src/ggml-cuda/softmax.cuh
@@ -3,3 +3,5 @@
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 1e940c5b..83bd76f9 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -67,6 +67,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
+ GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16,
+ GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4,
+ GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32,
+ GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4,
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -572,6 +576,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16, soft_cap_max_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, soft_cap_max_f32_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
@@ -872,6 +880,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_SUM_ROWS:
return true;
case GGML_OP_SOFTCAP:
+ case GGML_OP_SOFT_CAP_MAX:
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
@@ -1683,6 +1692,77 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_SOFT_CAP_MAX:
+ {
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+
+ int nth = 32; // SIMD width
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+ if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline;
+ }
+ } else {
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline;
+ }
+ }
+
+ float scale;
+ float max_bias;
+ float s_before;
+ float s_after;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+ memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before));
+ memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after));
+
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+ [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10];
+ [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_DIAG_MASK_INF:
{
const int n_past = ((int32_t *)(dst->op_params))[0];
@@ -2921,9 +3001,14 @@ static enum ggml_status ggml_metal_graph_compute(
float scale;
float max_bias;
+ float softcap;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+ memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap));
+ if (softcap != 0.0f) {
+ scale /= softcap;
+ }
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -2997,7 +3082,8 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
+ [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28];
if (!use_vec_kernel) {
// half8x8 kernel
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 2a0e84a6..f9c88a37 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -661,6 +661,221 @@ kernel void kernel_soft_max_4(
}
}
+template<typename T>
+kernel void kernel_soft_cap_max(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant float & s_before,
+ constant float & s_after,
+ constant uint32_t & n_head_log2,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+ float slope = 1.0f;
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const int64_t h = i02;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
+ // parallel max
+ float lmax = -INFINITY;
+
+ const float tot_scale = scale * s_after;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, precise::tanh(s_before*psrc0[i00])*tot_scale + (pmask ? slope*pmask[i00] : 0.0f));
+ }
+
+ // find the max value in the block
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float lsum = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        const float exp_psrc0 = exp(precise::tanh(s_before*psrc0[i00])*tot_scale + (pmask ? slope*pmask[i00] : 0.0f) - max_val);
+ lsum += exp_psrc0;
+ pdst[i00] = exp_psrc0;
+ }
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ pdst[i00] *= inv_sum;
+ }
+}
+
+template<typename T>
+kernel void kernel_soft_cap_max_4(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant float & s_before,
+ constant float & s_after,
+ constant uint32_t & n_head_log2,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+ float slope = 1.0f;
+
+ if (max_bias > 0.0f) {
+ const int64_t h = i02;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
+ const float tot_scale = scale * s_after;
+
+ // parallel max
+ float4 lmax4 = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, precise::tanh(s_before*psrc4[i00])*tot_scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+ }
+
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp(precise::tanh(s_before*psrc4[i00])*tot_scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ pdst4[i00] *= inv_sum;
+ }
+}
+
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
@@ -669,6 +884,14 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+typedef decltype(kernel_soft_cap_max<float>) kernel_soft_cap_max_t;
+typedef decltype(kernel_soft_cap_max_4<float4>) kernel_soft_cap_max_4_t;
+
+template [[host_name("kernel_soft_cap_max_f16")]] kernel kernel_soft_cap_max_t kernel_soft_cap_max<half>;
+template [[host_name("kernel_soft_cap_max_f32")]] kernel kernel_soft_cap_max_t kernel_soft_cap_max<float>;
+template [[host_name("kernel_soft_cap_max_f16_4")]] kernel kernel_soft_cap_max_4_t kernel_soft_cap_max_4<half4>;
+template [[host_name("kernel_soft_cap_max_f32_4")]] kernel kernel_soft_cap_max_4_t kernel_soft_cap_max_4<float4>;
+
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,
@@ -2056,6 +2279,7 @@ typedef void (flash_attn_ext_f16_t)(
constant float & max_bias,
constant float & m0,
constant float & m1,
+ constant float & softcap,
constant uint32_t & n_head_log2,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -2094,6 +2318,7 @@ kernel void kernel_flash_attn_ext_f16(
constant float & max_bias,
constant float & m0,
constant float & m1,
+ constant float & softcap,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -2223,14 +2448,19 @@ kernel void kernel_flash_attn_ext_f16(
const short tx = tiisg%4;
const short ty = tiisg/4;
+ // mqk = mqk*scale
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+
+ if (softcap != 0.0f) {
+ ss[8*cc + ty*TF + 2*tx + 0] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
+ ss[8*cc + ty*TF + 2*tx + 1] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
+ }
+
if (mask != q) {
// mqk = mqk*scale + mask*slope
- ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
- ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
- } else {
- // mqk = mqk*scale
- ss[8*cc + ty*TF + 2*tx + 0] *= scale;
- ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+ ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
+ ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
}
}
}
@@ -2425,6 +2655,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
constant float & max_bias,
constant float & m0,
constant float & m1,
+ constant float & softcap,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -2560,7 +2791,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
- mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
+ mqk *= scale;
+ if (softcap != 0.0f) {
+ mqk = softcap*precise::tanh(mqk);
+ }
+ mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
ss4[cc] = mqk;
}
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 60a89591..cebac584 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2889,6 +2889,41 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s
}
}
+static float ggml_vec_softcap_max_f32(const int n, float * x, float s_before, float s_after) {
+ int i = 0;
+ float max = -INFINITY;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ __m512 vs_before = _mm512_set1_ps(2.f*s_before);
+ __m512 vs_after = _mm512_set1_ps(s_after);
+ __m512 vmax = _mm512_set1_ps(-INFINITY);
+ for (; i + 15 < n; i += 16) {
+ __m512 y = ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after);
+ _mm512_storeu_ps(x + i, y);
+ vmax = _mm512_max_ps(vmax, y);
+ }
+ max = _mm512_reduce_max_ps(vmax);
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after));
+    }
+    for (int j = 0; j < i; ++j) max = MAX(max, x[j]); // the SIMD loop above does not track the max, so scan the updated prefix
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
+    }
+    for (int j = 0; j < i; ++j) max = MAX(max, x[j]);
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    float32x4_t vs_before = vdupq_n_f32(s_before);
+    float32x4_t vs_after = vdupq_n_f32(s_after);
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
+    }
+    for (int j = 0; j < i; ++j) max = MAX(max, x[j]);
+#endif
+ for (; i < n; ++i) {
+ x[i] = s_after*tanhf(x[i]*s_before);
+ max = MAX(max, x[i]);
+ }
+ return max;
+}
+
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
@@ -3144,6 +3179,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARGSORT",
"LEAKY_RELU",
"SOFTCAP",
+ "SOFT_CAP_MAX",
"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
@@ -3171,7 +3207,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3233,6 +3269,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"argsort(x)",
"leaky_relu(x)",
"k2*tanh(k1*x)",
+ "soft_max(k2*tanh(k1*x))",
"flash_attn_ext(x)",
"flash_attn_back(x)",
@@ -3260,7 +3297,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5963,6 +6000,72 @@ struct ggml_tensor * ggml_softcap_inplace(
return ggml_softcap_impl(ctx, a, s_before, s_after, true);
}
+static struct ggml_tensor * ggml_softcap_max_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias,
+ float s_before,
+ float s_after,
+ bool inplace) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(ggml_is_padded_1d(a));
+
+ if (mask) {
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(mask));
+ GGML_ASSERT(ggml_is_matrix(mask));
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+ }
+
+ if (max_bias > 0.0f) {
+ GGML_ASSERT(mask);
+ }
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ float params[4] = {scale, max_bias, s_before, s_after};
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_SOFT_CAP_MAX;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = mask;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_softcap_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias,
+ float s_before,
+ float s_after) {
+ return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false);
+}
+
+struct ggml_tensor * ggml_softcap_max_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias,
+ float s_before,
+ float s_after) {
+ return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true);
+}
+
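+// usage sketch (names are illustrative; cf. the llm_build_kqv change in src/llama.cpp below):
+// the fused op replaces the softcap + soft_max_ext pair, i.e.
+//     kq = ggml_softcap(ctx, kq, 1.0f/cap, cap);
+//     kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_bias);
+// becomes
+//     kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, max_bias, 1.0f/cap, cap);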
+
// ggml_set
static struct ggml_tensor * ggml_set_impl(
@@ -7493,7 +7596,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale,
- float max_bias) {
+ float max_bias,
+ float softcap) {
GGML_ASSERT(ggml_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
@@ -7520,7 +7624,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
- float params[] = { scale, max_bias };
+ float params[] = { scale, max_bias, softcap };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -7540,7 +7644,7 @@ void ggml_flash_attn_ext_set_prec(
const int32_t prec_i32 = (int32_t) prec;
- ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second, softcap on third
}
// ggml_flash_attn_back
@@ -13618,6 +13722,122 @@ static void ggml_compute_forward_softcap(
}
}
+static void ggml_compute_forward_softcap_max_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ assert(ggml_is_contiguous(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ float values[4];
+ memcpy(values, dst->op_params, sizeof(values));
+ // values[0] -> scale
+ // values[1] -> max_bias
+ // values[2] -> s_before
+ // values[3] -> s_after
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
+ // TODO: is this supposed to be ceil instead of floor?
+ // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+ const uint32_t n_head = ne02;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+ const float m0 = powf(2.0f, -(values[1] ) / n_head_log2);
+ const float m1 = powf(2.0f, -(values[1] / 2.0f) / n_head_log2);
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
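+    // per-thread scratch row in wdata, stride nc + CACHE_LINE_SIZE_F32 floats per thread
+    // (same scheme as GGML_OP_SOFT_MAX)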
+ float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ // ALiBi
+ const uint32_t h = (i1/ne01)%ne02; // head
+ const float slope = (values[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+ float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+
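+        // apply softcap and scale in one pass; assuming ggml_vec_cpy_softcap_f32(n, x, y, k1, k2)
+        // computes y[i] = k2*tanh(k1*x[i]), this yields wp[i] = scale*s_after*tanh(s_before*sp[i])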
+ ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]);
+
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
+ }
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+        //printf("wp[%d] = %f\n", i, wp[i]);
+ assert(!isnan(wp[i]));
+ }
+#endif
+
+ float max = -INFINITY;
+ ggml_vec_max_f32(nc, &max, wp);
+
+ ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
+ assert(sum > 0.0);
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(nc, dp, sum);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
+#endif
+ }
+
+}
+
+static void ggml_compute_forward_softcap_max(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_softcap_max_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_set
static void ggml_compute_forward_set_f32(
@@ -15919,9 +16139,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float scale = 1.0f;
float max_bias = 0.0f;
+ float softcap = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ memcpy(&softcap, (float *) dst->op_params + 2, sizeof(float));
+
+ if (softcap != 0.0f) {
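+        // fold 1/softcap into scale so that the KQ update below computes softcap*tanhf(s*scale),
+        // i.e. the softcapped, scaled logit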
+ scale /= softcap;
+ }
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -15985,7 +16211,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
- s = s*scale + mv; // scale KQ value and apply mask
+ s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
const float Mold = M;
@@ -15994,7 +16220,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
- if (v->type== GGML_TYPE_F16) {
+ if (v->type == GGML_TYPE_F16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
@@ -16061,7 +16287,7 @@ static void ggml_compute_forward_flash_attn_ext(
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
- switch (dst->op_params[2]) {
+ switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
@@ -17477,6 +17703,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_softcap(params, tensor);
} break;
+ case GGML_OP_SOFT_CAP_MAX:
+ {
+ ggml_compute_forward_softcap_max(params, tensor);
+ } break;
case GGML_OP_SET:
{
ggml_compute_forward_set(params, tensor);
@@ -18227,6 +18457,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
+ case GGML_OP_SOFT_CAP_MAX:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_SET:
{
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -19240,6 +19474,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SCALE:
case GGML_OP_SOFTCAP:
case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_CAP_MAX:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
} break;
diff --git a/src/llama.cpp b/src/llama.cpp
index 831f98dc..8a85144e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8290,7 +8290,8 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
- cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -8324,10 +8325,12 @@ static struct ggml_tensor * llm_build_kqv(
}
if (hparams.attn_soft_cap) {
- kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
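+            // fused softcap + softmax: a single GGML_OP_SOFT_CAP_MAX op instead of
+            // ggml_softcap followed by ggml_soft_max_ext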
+ kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ } else {
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
}
-
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
GGML_ASSERT(kv.size == n_ctx);
@@ -13220,47 +13223,31 @@ struct llm_build_context {
0);
cb(k, "k", il);
- if (cparams.flash_attn) {
-
- // split cached v into n_head heads (not transposed)
- struct ggml_tensor * v =
- ggml_view_3d(ctx0, kv_self.v_l[il],
- n_embd_head_v, n_kv, n_head_kv,
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
- 0);
- cb(v, "v", il);
-
- cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-
- cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
- } else {
- struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
- cb(kq, "kq", il);
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
- kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
- cb(kq, "kq_soft_max_ext", il);
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
- GGML_ASSERT(kv_self.size == n_ctx);
+ GGML_ASSERT(kv_self.size == n_ctx);
- // split cached v into n_head heads
- struct ggml_tensor * v =
- ggml_view_3d(ctx0, kv_self.v_l[il],
- n_kv, n_embd_head_v, n_head_kv,
- ggml_element_size(kv_self.v_l[il])*n_ctx,
- ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
- 0);
- cb(v, "v", il);
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv_self.v_l[il])*n_ctx,
+ ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
- struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
- cb(kqv, "kqv", il);
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);
- struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
- cb(kqv_merged, "kqv_merged", il);
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
- cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
- cb(cur_attn, "kqv_merged_cont", il);
- }
+ cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cb(cur_attn, "kqv_merged_cont", il);
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
model.layers[il].attn_sub_norm, NULL,
@@ -16811,12 +16798,6 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}
- if (params.flash_attn && model->hparams.attn_soft_cap) {
- LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
- params.flash_attn = false;
- }
-
-
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a2182c1b..f51ec5b8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case {
const bool mask; // use mask
const float max_bias; // ALiBi
+    const float softcap; // attention logit softcapping (Gemma-2)
const ggml_type type_KV;
std::string vars() override {
- return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
+ return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, softcap, type_KV);
}
double max_nmse_err() override {
return 5e-4;
}
- test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
- : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
+ test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+ : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), softcap(softcap), type_KV(type_KV) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case {
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
- ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
+ ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, softcap);
return out;
}
};
@@ -2434,11 +2435,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (bool mask : { true, false } ) {
for (float max_bias : { 0.0f, 8.0f }) {
if (!mask && max_bias > 0.0f) continue;
- for (int nh : { 32, }) {
- for (int kv : { 512, 1024, }) {
- for (int nb : { 1, 2, 4, 8, }) {
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+ for (float softcap : {0.0f, 10.0f}) {
+ if (hs != 128 && softcap != 0.0f) continue;
+ for (int nh : { 32, }) {
+ for (int kv : { 512, 1024, }) {
+ for (int nb : { 1, 2, 4, 8, }) {
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, softcap, type_KV));
+ }
}
}
}