summaryrefslogtreecommitdiff
path: root/llama.h
diff options
context:
space:
mode:
Diffstat (limited to 'llama.h')
-rw-r--r--llama.h34
1 files changed, 31 insertions, 3 deletions
diff --git a/llama.h b/llama.h
index 947284ea..ff131996 100644
--- a/llama.h
+++ b/llama.h
@@ -64,6 +64,15 @@ extern "C" {
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
};
+ // note: these values should be synchronized with ggml_rope
+ // TODO: maybe move this enum to ggml.h (ggml_rope_type)
+ enum llama_rope_type {
+ LLAMA_ROPE_TYPE_NONE = -1,
+ LLAMA_ROPE_TYPE_NORM = 0,
+ LLAMA_ROPE_TYPE_NEOX = 2,
+ LLAMA_ROPE_TYPE_GLM = 4,
+ };
+
enum llama_token_type {
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
@@ -360,6 +369,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
+ LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -514,10 +524,12 @@ extern "C" {
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
- // If the KV cache is RoPEd, the KV data is updated accordingly
+ // If the KV cache is RoPEd, the KV data is updated accordingly:
+ // - lazily on next llama_decode()
+ // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
- LLAMA_API void llama_kv_cache_seq_shift(
+ LLAMA_API void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@@ -525,7 +537,9 @@ extern "C" {
llama_pos delta);
// Integer division of the positions by factor of `d > 1`
- // If the KV cache is RoPEd, the KV data is updated accordingly
+ // If the KV cache is RoPEd, the KV data is updated accordingly:
+ // - lazily on next llama_decode()
+ // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div(
@@ -535,6 +549,20 @@ extern "C" {
llama_pos p1,
int d);
+ // Returns the largest position present in the KV cache for the specified sequence
+ LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+ struct llama_context * ctx,
+ llama_seq_id seq_id);
+
+ // Defragment the KV cache
+ // This will be applied:
+ // - lazily on next llama_decode()
+ // - explicitly with llama_kv_cache_update()
+ LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
+
+ // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+ LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
+
//
// State / sessions
//