Diffstat (limited to 'examples')
-rwxr-xr-x  examples/chat-persistent.sh          | 8
-rw-r--r--  examples/main/main.cpp               | 3
-rw-r--r--  examples/parallel/parallel.cpp       | 2
-rw-r--r--  examples/server/server.cpp           | 2
-rw-r--r--  examples/speculative/speculative.cpp | 6
5 files changed, 12 insertions, 9 deletions
diff --git a/examples/chat-persistent.sh b/examples/chat-persistent.sh
index e0c251e5..22f5b83d 100755
--- a/examples/chat-persistent.sh
+++ b/examples/chat-persistent.sh
@@ -9,7 +9,7 @@ if [[ -z "${PROMPT_CACHE_FILE+x}" || -z "${CHAT_SAVE_DIR+x}" ]]; then
     exit 1
 fi
 
-MODEL="${MODEL:-./models/13B/ggml-model-q4_0.bin}"
+MODEL="${MODEL:-./models/llama-13b/ggml-model-q4_0.gguf}"
 PROMPT_TEMPLATE="${PROMPT_TEMPLATE:-./prompts/chat.txt}"
 USER_NAME="${USER_NAME:-User}"
 AI_NAME="${AI_NAME:-ChatLLaMa}"
@@ -61,9 +61,9 @@ fi
 
 if [[ ! -e "$PROMPT_CACHE_FILE" ]]; then
     echo 'Prompt cache does not exist, building...'
-    # Default batch_size to 8 here for better user feedback during initial prompt processing
+    # Default batch_size to 64 here for better user feedback during initial prompt processing
     ./main 2>>"$LOG" \
-        --batch_size 8 \
+        --batch_size 64 \
         "${OPTS[@]}" \
         --prompt-cache "$PROMPT_CACHE_FILE" \
         --file "$CUR_PROMPT_FILE" \
@@ -132,7 +132,7 @@ while read -e line; do
     # HACK get num tokens from debug message
     # TODO get both messages in one go
     if ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" ||
-        ! sample_time_msg="$( tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
+        ! sample_time_msg="$(tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
         echo >&2 "Couldn't get number of tokens from ./main output!"
         exit 1
     fi
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3a4ed3f7..7367ae36 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -543,6 +543,9 @@ int main(int argc, char ** argv) {
                 if (i > 0) {
                     embd.erase(embd.begin(), embd.begin() + i);
                 }
+
+                // remove any "future" tokens that we might have inherited from the session from the KV cache
+                llama_kv_cache_tokens_rm(ctx, n_past, -1);
             }
 
             // evaluate tokens in batches
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 0434ded2..ffd7b1db 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -332,7 +332,7 @@ int main(int argc, char ** argv) {
                 }
 
                 // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
+                llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
 
                 const auto t_main_end = ggml_time_us();
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6dda5e36..921eb5da 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -448,7 +448,7 @@ struct llama_server_context
             n_past = common_part(embd, prompt_tokens);
 
             // since #3228 we now have to manually manage the KV cache
-            llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+            llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
 
             embd = prompt_tokens;
             if (n_past == num_prompt_tokens)
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index c5e5b234..75a2e5e2 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
                 LOG("out of drafted tokens\n");
             }
 
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
             ++n_past_dft;
 
@@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
             }
 
             // evaluate the drafted token on the draft model
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
             ++n_past_cur;
 
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
-        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
+        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
         ++n_past_tgt;
 
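The common thread in the C++ changes above is the final argument of llama_kv_cache_seq_rm: the examples now pass -1 instead of n_ctx as the end position, which reads as "remove this sequence's cached entries from p0 to the end of the cache" rather than capping the range at the context size. A minimal sketch of that pattern, for illustration only: the helper name reset_sequence_tail and its parameters are hypothetical, and only the llama_kv_cache_seq_rm call itself is taken from the diff.

#include "llama.h"

// Illustrative only: drop the cached KV entries of one sequence from position p0
// onward, so that subsequent llama_decode() calls re-fill the cache starting at p0.
// Passing -1 as the end position (instead of n_ctx) is read here as
// "up to the end of the cache", matching the calls in this commit.
static void reset_sequence_tail(llama_context * ctx, llama_seq_id seq_id, llama_pos p0) {
    llama_kv_cache_seq_rm(ctx, seq_id, p0, -1);
}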