diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-27 17:40:59 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-27 17:40:59 +0300 |
commit | c7e99c88a2de7489ba2a1539b1a9025912010b70 (patch) | |
tree | 9976409b1e8fac1fc7486f2c5da05a33b8e229b5 /ggml/include/ggml.h | |
parent | bd99ed7d0afd2b12c0f5ff5c17b58486396dfe7e (diff) |
Faster Gemma2 (#27)
* soft_cap_max: initial CPU version of fused softcap + soft_max
With this vanilla CPU implementation I'm already getting a ~3% speedup
for Gemma-2-9b and a prompt of 8192 tokens.
* soft_cap_max: WIP - something is wrong with CUDA
* soft_cap_max: looks good on CPU and CUDA
* Add softcap to flash attention
Just CPU and CUDA for now (but, as we know, flash attention
on the CPU is useless in llama.cpp).
On CUDA this improves PP performance quite a bit, especially for
long contexts. E.g., for PP-16384, I now get 3777 t/s.
Without this change, one cannot use FA, and one gets 2300 t/s
(after fusing softcap and softmax), or 2000 t/s without the
fused softcap+softmax.
In comparison, mainline llama.cpp has PP-16384 = 1549 t/s before
PR-8542 (where Johannes Gaessler has also added softcap to FA),
and PP-16384 = 3097 t/s after this PR.
* soft_cap_max: Metal
* Flash attention with softcap: Metal
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/include/ggml.h')
-rw-r--r-- | ggml/include/ggml.h | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 17d3cb1a..1a4a516c 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -515,6 +515,7 @@ extern "C" { GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, GGML_OP_SOFTCAP, + GGML_OP_SOFT_CAP_MAX, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, @@ -1237,6 +1238,25 @@ extern "C" { float s_before, float s_after); + GGML_API struct ggml_tensor * ggml_softcap_max( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias, + float s_before, + float s_after); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_softcap_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias, + float s_before, + float s_after); + // b -> view(a,offset,nb1,nb2,3), return modified a GGML_API struct ggml_tensor * ggml_set( struct ggml_context * ctx, @@ -1791,7 +1811,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias); + float max_bias, + float softcap); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, |