| author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-20 17:15:47 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-20 17:15:47 +0300 |
| commit | d259a50ca6fd3a0821abe6a16b73c0b19c5b4651 (patch) | |
| tree | 4f83bbbbbbd9323192d8c0bceb51de5b0fb620c2 /ggml/include | |
| parent | a325745000114a43c1546323f91720db503ed0a9 (diff) | |
Fused soft cap and SIMD-ified GeLU (#9)
* softcap: WIP
Fuses scale + tanh + scale as used for softcapping in some models
(a usage sketch of the fused op follows the diff at the bottom).
CPU only for now. ~1.4% speedup for PP-512 on Gemma2-9b, no effect on TG.
Somewhat surprisingly, the improvement does not grow as I go to longer
contexts. Gemma2 applies the softcap to K*Q, which grows quadratically
with context length, so I would have expected the benefit from fusing
scale, tanh, scale to increase. But no luck.
* softcap: CUDA
~1% speedup for Gemma2-9b
* softcap: Metal and NEON
About 1% speedup.
* SIMD-ified GeLU
Gives ~1% speedup for Gemma2-9b prompt processing on AVX512/AVX2.
It looks like the GeLU operation is memory bound on my CPUs after
SIMD-ifying it, but by not using the 128 kB GeLU lookup table we still
gain a small advantage (a scalar reference of the approximation being
vectorized follows the commit message).
On the M2-Max the lookup table is slightly faster than the SIMD
version, so the lookup table is kept for ARM_NEON.
* softcap, tanh: avoid NaNs for large arguments (AVX2, AVX512)
Not that I have encountered this in practice, but just to be sure.
This covers AVX512 and AVX2; ARM_NEON still needs a guard.
* llama-bench: add ability to turn off warmup runs
So we don't need to wait forever on, e.g., benchmarks involving
long contexts.
* softcap, tanh: avoid NaNs for large arguments (NEON)
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
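For context on the SIMD-ified GeLU above: ggml's GeLU is the tanh-based
approximation, and the lookup table mentioned stores precomputed fp16 results.
Below is a minimal scalar sketch (not code from this commit; gelu_ref is a
hypothetical name) of the expression the vectorized path evaluates directly:

#include <math.h>

// Scalar reference of ggml's tanh-approximation GeLU (sketch, not commit code).
// The AVX512/AVX2 path evaluates this expression in SIMD registers instead of
// reading precomputed fp16 results from the 128 kB lookup table.
static float gelu_ref(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654f; // sqrt(2/pi)
    const float gelu_coef_a    = 0.044715f;
    return 0.5f*x*(1.0f + tanhf(sqrt_2_over_pi*x*(1.0f + gelu_coef_a*x*x)));
}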
Diffstat (limited to 'ggml/include')
| -rw-r--r-- | ggml/include/ggml.h | 14 |
1 file changed, 14 insertions, 0 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 026993db..17d3cb1a 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
+        GGML_OP_SOFTCAP,
 
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
@@ -1223,6 +1224,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 s);
 
+    GGML_API struct ggml_tensor * ggml_softcap(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s_before,
+            float                 s_after);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_softcap_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s_before,
+            float                 s_after);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
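A minimal usage sketch for the API added above, assuming (as the parameter
names suggest) that ggml_softcap(ctx, a, s_before, s_after) builds the fused
equivalent of scale -> tanh -> scale, i.e. s_after * tanh(s_before * a). The
helper apply_softcap is an illustrative placeholder, not part of this commit:

#include "ggml.h"

// Sketch: build the softcapping step with the fused op instead of the
// scale -> tanh -> scale chain it replaces (hypothetical helper).
static struct ggml_tensor * apply_softcap(struct ggml_context * ctx,
                                          struct ggml_tensor  * cur,
                                          float                 softcap) {
    // unfused pattern, as used e.g. for Gemma2's attention logits:
    //   cur = ggml_scale(ctx, cur, 1.0f/softcap);
    //   cur = ggml_tanh (ctx, cur);
    //   cur = ggml_scale(ctx, cur, softcap);
    return ggml_softcap(ctx, cur, 1.0f/softcap, softcap);
}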