path: root/llama.h
author     Jan Boon <jan.boon@kaetemi.be>  2024-04-08 20:43:30 +0800
committer  GitHub <noreply@github.com>     2024-04-08 15:43:30 +0300
commit     beea6e1b16e783a0886e78dec01002a8c00db24d (patch)
tree       a7365b1e93145b78a8b4be72df959239aa8c0f0d /llama.h
parent     87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff)
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot

---------

Co-authored-by: Martin Evans <martindevans@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
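The rename keeps the old entry points alive under a DEPRECATED notice, so existing callers keep compiling. A minimal sketch of the renamed whole-context API (the signatures are taken from the diff below; the helper names are made up for illustration):

#include <stdlib.h>
#include "llama.h"

// Snapshot the full context state (rng, logits, embedding, kv_cache) into a
// freshly allocated buffer; returns NULL on allocation failure.
static uint8_t * state_snapshot(struct llama_context * ctx, size_t * n_used) {
    const size_t n_max = llama_state_get_size(ctx); // upper bound on the state size
    uint8_t * buf = malloc(n_max);
    if (buf == NULL) {
        return NULL;
    }
    *n_used = llama_state_get_data(ctx, buf); // bytes actually written
    return buf;
}

// Restore a previously captured snapshot; returns the number of bytes read.
static size_t state_restore(struct llama_context * ctx, const uint8_t * buf) {
    return llama_state_set_data(ctx, buf);
}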
Diffstat (limited to 'llama.h')
-rw-r--r--  llama.h  73
1 file changed, 68 insertions, 5 deletions
diff --git a/llama.h b/llama.h
index 036b3268..2250130e 100644
--- a/llama.h
+++ b/llama.h
@@ -37,10 +37,14 @@
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
+#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 5
 
+#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
+#define LLAMA_STATE_SEQ_VERSION 1
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
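The new GGSQ magic and version tag the per-sequence state files introduced by this patch. A hedged sketch of validating such a file's header, assuming (by analogy with the existing GGSN session files; the writer is not shown in this hunk) that the magic and version are stored as the first two little-endian uint32 fields:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "llama.h"

// Check whether a file starts with the sequence-state magic and version.
// Assumption: llama_state_seq_save_file writes them as two uint32 fields.
static bool has_state_seq_header(const char * filepath) {
    FILE * f = fopen(filepath, "rb");
    if (f == NULL) {
        return false;
    }
    uint32_t magic = 0, version = 0;
    const bool ok =
        fread(&magic,   sizeof(magic),   1, f) == 1 &&
        fread(&version, sizeof(version), 1, f) == 1 &&
        magic   == LLAMA_STATE_SEQ_MAGIC &&
        version == LLAMA_STATE_SEQ_VERSION;
    fclose(f);
    return ok;
}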
@@ -523,6 +527,7 @@
             struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
     // seq_id < 0 : match any sequence
     // p0 < 0 : [0, p1]
     // p1 < 0 : [p0, inf)
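The added comment documents a behavior change from this patch: partial removal can now report failure instead of asserting, while removing a whole sequence never fails. A small sketch of acting on that return value, assuming the declaration that follows this comment is llama_kv_cache_seq_rm (the hunk cuts it off):

#include <stdbool.h>
#include "llama.h"

// Remove [p0, p1) from seq_id, falling back to dropping the whole sequence,
// which is documented never to fail. (Sketch; not part of the patch.)
static void remove_range_or_whole_seq(struct llama_context * ctx, llama_seq_id seq_id,
                                      llama_pos p0, llama_pos p1) {
    if (!llama_kv_cache_seq_rm(ctx, seq_id, p0, p1)) {
        // p0 < 0 and p1 < 0 select the full range [0, inf)
        llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
    }
}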
@@ -594,34 +599,92 @@
 
     // Returns the maximum size in bytes of the state (rng, logits, embedding
    // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
+    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
+    LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+        "use llama_state_get_size instead");
 
     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
-    LLAMA_API size_t llama_copy_state_data(
+    LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
             uint8_t * dst);
+    LLAMA_API DEPRECATED(size_t llama_copy_state_data(
+            struct llama_context * ctx,
+            uint8_t * dst),
+        "use llama_state_get_data instead");
 
     // Set the state reading from the specified address
     // Returns the number of bytes read
-    LLAMA_API size_t llama_set_state_data(
+    LLAMA_API size_t llama_state_set_data(
            struct llama_context * ctx,
            const uint8_t * src);
+    LLAMA_API DEPRECATED(size_t llama_set_state_data(
+            struct llama_context * ctx,
+            const uint8_t * src),
+        "use llama_state_set_data instead");
 
     // Save/load session file
-    LLAMA_API bool llama_load_session_file(
+    LLAMA_API bool llama_state_load_file(
             struct llama_context * ctx,
             const char * path_session,
             llama_token * tokens_out,
             size_t n_token_capacity,
             size_t * n_token_count_out);
+    LLAMA_API DEPRECATED(bool llama_load_session_file(
+            struct llama_context * ctx,
+            const char * path_session,
+            llama_token * tokens_out,
+            size_t n_token_capacity,
+            size_t * n_token_count_out),
+        "use llama_state_load_file instead");
 
-    LLAMA_API bool llama_save_session_file(
+    LLAMA_API bool llama_state_save_file(
             struct llama_context * ctx,
             const char * path_session,
             const llama_token * tokens,
             size_t n_token_count);
+    LLAMA_API DEPRECATED(bool llama_save_session_file(
+            struct llama_context * ctx,
+            const char * path_session,
+            const llama_token * tokens,
+            size_t n_token_count),
+        "use llama_state_save_file instead");
+
+    // Get the exact size needed to copy the KV cache of a single sequence
+    LLAMA_API size_t llama_state_seq_get_size(
+            struct llama_context * ctx,
+            llama_seq_id seq_id);
+
+    // Copy the KV cache of a single sequence into the specified buffer
+    LLAMA_API size_t llama_state_seq_get_data(
+            struct llama_context * ctx,
+            uint8_t * dst,
+            llama_seq_id seq_id);
+
+    // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
+    // Returns:
+    //  - Positive: Ok
+    //  - Zero: Failed to load
+    LLAMA_API size_t llama_state_seq_set_data(
+            struct llama_context * ctx,
+            const uint8_t * src,
+            llama_seq_id dest_seq_id);
+
+    LLAMA_API size_t llama_state_seq_save_file(
+            struct llama_context * ctx,
+            const char * filepath,
+            llama_seq_id seq_id,
+            const llama_token * tokens,
+            size_t n_token_count);
+
+    LLAMA_API size_t llama_state_seq_load_file(
+            struct llama_context * ctx,
+            const char * filepath,
+            llama_seq_id dest_seq_id,
+            llama_token * tokens_out,
+            size_t n_token_capacity,
+            size_t * n_token_count_out);
 
     //
     // Decoding
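Taken together, the new llama_state_seq_* entry points support both in-memory and file-based transfer of a single sequence's KV cache. A hedged end-to-end sketch using only the signatures above; copy_seq and save_seq are illustrative helpers, and treating a zero return from llama_state_seq_save_file as failure is an assumption carried over from the documented llama_state_seq_set_data convention:

#include <stdbool.h>
#include <stdlib.h>
#include "llama.h"

// Duplicate the KV cache of seq_src into seq_dst within the same context.
static bool copy_seq(struct llama_context * ctx, llama_seq_id seq_src, llama_seq_id seq_dst) {
    const size_t n_size = llama_state_seq_get_size(ctx, seq_src); // exact size
    uint8_t * buf = malloc(n_size);
    if (buf == NULL) {
        return false;
    }
    const size_t n_copied = llama_state_seq_get_data(ctx, buf, seq_src);
    // zero means the restore failed (e.g. no space left in the kv cache)
    const size_t n_set = n_copied > 0 ? llama_state_seq_set_data(ctx, buf, seq_dst) : 0;
    free(buf);
    return n_set > 0;
}

// Persist one sequence plus its token history; the token bookkeeping is the
// caller's responsibility, as with session files.
static bool save_seq(struct llama_context * ctx, const char * path,
                     llama_seq_id seq_id, const llama_token * tokens, size_t n_tokens) {
    // assumed: returns the number of bytes written, 0 on failure
    return llama_state_seq_save_file(ctx, path, seq_id, tokens, n_tokens) > 0;
}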