diff options
author | Johannes Gäßler <johannesg@5d6.de> | 2024-05-18 12:36:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-18 12:36:25 +0200 |
commit | 133d99c59980139f5bb75922c8b5fca67d7ba9b8 (patch) | |
tree | 6a84c2c449dcd23e909db087b4594444b8622c71 /ggml-cuda/common.cuh | |
parent | cb42c294279bc4a0a4e926a7b5a5568049f12fa7 (diff) |
CUDA: deduplicate FlashAttention code (#7352)
Diffstat (limited to 'ggml-cuda/common.cuh')
-rw-r--r-- | ggml-cuda/common.cuh | 11 |
1 file changed, 11 insertions, 0 deletions
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 784792ba..8f6fd71c 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
 //////////////////////