diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-02-17 23:04:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-17 23:04:16 +0200 |
commit | 8f1be0d42f23016cb6819dbae01126699c4bd9bc (patch) | |
tree | 4a142e745a73307190e9c5ef5c41aeb4aadaca7a /ggml.h | |
parent | 6e4e973b2615f8d390b1c4f4a7e05a119078bb0f (diff) |
ggml : add ALiBi support for ggml_soft_max_ext (#5488)
* ggml : avoid recomputing alibi slopes (CPU)
* llama : reuse hparams.f_max_alibi_bias in all cases
ggml-ci
* ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal)
ggml-ci
* ggml : handle all SRCs (do not break on first null)
ggml-ci
* tests : do not use slope for large soft_max
accumulates too much error
ggml-ci
* ggml : alternative ALiBi without extra tensor
We compute the slopes in the kernel
ggml-ci
* cuda : add ALiBi support in ggml_soft_max_ext
ggml-ci
* ggml : deprecate ggml_alibi
* ggml : support multi-sequence ALiBi (Metal)
ggml-ci
* cuda : add multi-seq ALiBi + remote F16 soft_max
ggml-ci
* ggml : update deprecation message
* ggml : fix pos ptr when no ALiBi
ggml-ci
* cuda : fix performance (pow -> powf)
* cuda : precompute ALiBi constants
* metal : pre-compute ALiBi slopes
ggml-ci
* llama : init kq_pos only if needed
ggml-ci
* test-backend-ops : add null pos test to soft_max
test-backend-ops : replace soft_max tests
ggml-ci
---------
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'ggml.h')
-rw-r--r-- | ggml.h | 13 |
1 files changed, 9 insertions, 4 deletions
@@ -1383,13 +1383,17 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // fused soft_max(a*scale + mask) + // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope)) // mask is optional + // pos is required when max_bias > 0.0f + // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - float scale); + struct ggml_tensor * pos, + float scale, + float max_bias); GGML_API struct ggml_tensor * ggml_soft_max_back( struct ggml_context * ctx, @@ -1491,12 +1495,13 @@ extern "C" { // alibi position embedding // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_alibi( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, int n_head, - float bias_max); + float bias_max), + "use ggml_soft_max_ext instead (will be removed in Mar 2024)"); // clamp // in-place, returns view(a) |