author     Georgi Gerganov <ggerganov@gmail.com>   2023-10-18 16:21:57 +0300
committer  GitHub <noreply@github.com>             2023-10-18 16:21:57 +0300
commit     0e89203b517c95ec6675eda75d200a60d1e8921d (patch)
tree       3aba40ef0362d061f240bd43c52e86a8f728f89d
parent     c67fe68e417f766970fb1feaf2e66458aa24116a (diff)
speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context  ggml-ci
* speculative : add tree-based sampling support  ggml-ci
* speculative : reuse the n_parallel CLI param
* speculative : refactor sampling
* examples : fix build after sampling refactoring  ggml-ci
* batched : fix n_seq_id
* sampling : fix malloc  ggml-ci
* swift : fix build  ggml-ci
* swift : try to fix build  ggml-ci
* prompts : add assistant.txt
* common : add llama_batch_add() and llama_batch_clear() helpers
* speculative : minor refactor  ggml-ci
* minor : comments + rename  ggml-ci
* speculative : fix off-by-one for n_drafted
* speculative : fix the n_drafted fix + p constants
-rw-r--r--  Makefile                                    |   2
-rw-r--r--  common/common.cpp                           |  21
-rw-r--r--  common/common.h                             |  16
-rw-r--r--  common/log.h                                | 101
-rw-r--r--  common/sampling.cpp                         | 211
-rw-r--r--  common/sampling.h                           |  87
-rw-r--r--  examples/batched-bench/batched-bench.cpp    |  38
-rw-r--r--  examples/batched.swift/Sources/main.swift   |  14
-rw-r--r--  examples/batched/batched.cpp                |  26
-rw-r--r--  examples/embd-input/embd-input-lib.cpp      |   2
-rw-r--r--  examples/infill/infill.cpp                  |  44
-rw-r--r--  examples/llava/llava-utils.h                |   2
-rw-r--r--  examples/llava/llava.cpp                    |   2
-rw-r--r--  examples/main/main.cpp                      |  92
-rw-r--r--  examples/parallel/parallel.cpp              |  70
-rw-r--r--  examples/server/server.cpp                  |  73
-rw-r--r--  examples/simple/simple.cpp                  |   2
-rw-r--r--  examples/speculative/speculative.cpp        | 367
-rw-r--r--  llama.cpp                                   |  95
-rw-r--r--  llama.h                                     |  17
-rw-r--r--  prompts/assistant.txt                       |  31
21 files changed, 736 insertions, 577 deletions
diff --git a/Makefile b/Makefile
index 9a8faef4..04104bee 100644
--- a/Makefile
+++ b/Makefile
@@ -545,7 +545,7 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h l
$(CXX) $(CXXFLAGS) -c $< -o $@
COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
-COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o
+COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o grammar-parser.o
common.o: common/common.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@
diff --git a/common/common.cpp b/common/common.cpp
index 3e4b8a8c..ce14d66b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return cparams;
}
+void llama_batch_clear(struct llama_batch & batch) {
+ batch.n_tokens = 0;
+}
+
+void llama_batch_add(
+ struct llama_batch & batch,
+ llama_token id,
+ llama_pos pos,
+ const std::vector<llama_seq_id> & seq_ids,
+ bool logits) {
+ batch.token [batch.n_tokens] = id;
+ batch.pos [batch.n_tokens] = pos,
+ batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+ for (size_t i = 0; i < seq_ids.size(); ++i) {
+ batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+ }
+ batch.logits [batch.n_tokens] = logits;
+
+ batch.n_tokens++;
+}
+
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);
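
Note: a minimal sketch of how the two helpers added above are intended to be used when preparing a prompt batch; the batch/prompt variables are illustrative, only the helpers themselves come from this patch:

    // clear the batch, then append the prompt tokens to sequence 0
    llama_batch_clear(batch);

    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
        // token id, position, sequence ids, whether logits are needed
        llama_batch_add(batch, prompt_tokens[i], (llama_pos) i, { 0 }, false);
    }

    // llama_decode only needs logits for the last prompt token
    batch.logits[batch.n_tokens - 1] = true;
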
diff --git a/common/common.h b/common/common.h
index 08c60323..65d3d20c 100644
--- a/common/common.h
+++ b/common/common.h
@@ -70,6 +70,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
+ // TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
@@ -124,10 +125,23 @@ void process_escapes(std::string& input);
// Model utils
//
+// TODO: avoid tuple, use struct
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
-struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
+
+struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
+// Batch utils
+
+void llama_batch_clear(struct llama_batch & batch);
+
+void llama_batch_add(
+ struct llama_batch & batch,
+ llama_token id,
+ llama_pos pos,
+ const std::vector<llama_seq_id> & seq_ids,
+ bool logits);
+
//
// Vocab utils
//
diff --git a/common/log.h b/common/log.h
index b8953fdc..70e7e4ca 100644
--- a/common/log.h
+++ b/common/log.h
@@ -579,38 +579,75 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
return buf.str();
}
-#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \
- [&tokens, &ctx]() \
- { \
- std::stringstream buf; \
- buf << "[ "; \
- \
- bool first = true; \
- for (const auto &token : tokens) \
- { \
- if (!first) \
- buf << ", "; \
- else \
- first = false; \
- \
- auto detokenized = llama_token_to_piece(ctx, token); \
- \
- detokenized.erase( \
- std::remove_if( \
- detokenized.begin(), \
- detokenized.end(), \
- [](const unsigned char c) { return !std::isprint(c); }), \
- detokenized.end()); \
- \
- buf \
- << "'" << detokenized << "'" \
- << ":" << std::to_string(token); \
- } \
- buf << " ]"; \
- \
- return buf.str(); \
- }() \
- .c_str()
+template <typename C, typename T>
+inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
+{
+ std::stringstream buf;
+ buf << "[ ";
+
+ bool first = true;
+ for (const auto &token : tokens)
+ {
+ if (!first) {
+ buf << ", ";
+ } else {
+ first = false;
+ }
+
+ auto detokenized = llama_token_to_piece(ctx, token);
+
+ detokenized.erase(
+ std::remove_if(
+ detokenized.begin(),
+ detokenized.end(),
+ [](const unsigned char c) { return !std::isprint(c); }),
+ detokenized.end());
+
+ buf
+ << "'" << detokenized << "'"
+ << ":" << std::to_string(token);
+ }
+ buf << " ]";
+
+ return buf.str();
+}
+
+template <typename C, typename B>
+inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
+{
+ std::stringstream buf;
+ buf << "[ ";
+
+ bool first = true;
+ for (int i = 0; i < batch.n_tokens; ++i)
+ {
+ if (!first) {
+ buf << ", ";
+ } else {
+ first = false;
+ }
+
+ auto detokenized = llama_token_to_piece(ctx, batch.token[i]);
+
+ detokenized.erase(
+ std::remove_if(
+ detokenized.begin(),
+ detokenized.end(),
+ [](const unsigned char c) { return !std::isprint(c); }),
+ detokenized.end());
+
+ buf
+ << "\n" << std::to_string(i)
+ << ":token '" << detokenized << "'"
+ << ":pos " << std::to_string(batch.pos[i])
+ << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
+ << ":seq_id " << std::to_string(batch.seq_id[i][0])
+ << ":logits " << std::to_string(batch.logits[i]);
+ }
+ buf << " ]";
+
+ return buf.str();
+}
#ifdef LOG_DISABLE_LOGS
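
Note: LOG_TOKENS_TOSTR_PRETTY and the new LOG_BATCH_TOSTR_PRETTY are now inline function templates returning std::string rather than macros expanding to a const char *, so printf-style call sites must convert explicitly. A short sketch of the adjusted usage, mirroring the call-site changes later in this patch:

    // before: LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
    LOG("batch:  %s\n", LOG_BATCH_TOSTR_PRETTY(ctx, batch).c_str());
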
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 8ce41945..0b246658 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,64 +1,81 @@
#include "sampling.h"
-llama_sampling_context::~llama_sampling_context() {
- for (auto & it : sequence_contexts) {
- if (it.second.grammar != NULL) {
- llama_grammar_free(it.second.grammar);
- it.second.grammar = NULL;
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
+ struct llama_sampling_context * result = new llama_sampling_context();
+
+ result->params = params.sampling_params;
+ result->grammar = nullptr;
+
+ // if there is a grammar, parse it
+ if (!params.grammar.empty()) {
+ result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+
+ // will be empty (default) if there are parse errors
+ if (result->parsed_grammar.rules.empty()) {
+ fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+ return nullptr;
}
+
+ std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+ result->grammar = llama_grammar_init(
+ grammar_rules.data(),
+ grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}
+
+ result->prev.resize(params.n_ctx);
+
+ return result;
}
-llama_sampling_context llama_sampling_context_init(
- const struct gpt_params & params,
- llama_grammar * grammar) {
- llama_sampling_context result;
+void llama_sampling_free(struct llama_sampling_context * ctx) {
+ if (ctx->grammar != NULL) {
+ llama_grammar_free(ctx->grammar);
+ }
- result.params = params.sampling_params;
- result.grammar = grammar;
- return result;
+ delete ctx;
}
-// Note: Creates the context if it doesn't exist, so this always return something.
-llama_sampler_sequence_context & llama_sampling_get_sequence_context(
- llama_sampling_context & ctx_sampling,
- const llama_seq_id seq) {
- const auto it = ctx_sampling.sequence_contexts.find(seq);
- if (it != ctx_sampling.sequence_contexts.end()) {
- return it->second;
+void llama_sampling_reset(llama_sampling_context * ctx) {
+ if (ctx->grammar != NULL) {
+ llama_grammar_free(ctx->grammar);
}
- llama_sampler_sequence_context new_ctx = {
- 2.0f * ctx_sampling.params.mirostat_tau,
- ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
- };
- return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
+
+ if (!ctx->parsed_grammar.rules.empty()) {
+ std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+
+ ctx->grammar = llama_grammar_init(
+ grammar_rules.data(),
+ grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+ }
+
+ std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
+ ctx->cur.clear();
}
-bool llama_sampling_context_reset(
- llama_sampling_context & ctx_sampling,
- const llama_seq_id seq) {
- const auto it = ctx_sampling.sequence_contexts.find(seq);
- if (it == ctx_sampling.sequence_contexts.end()) return false;
- if (it->second.grammar != NULL) {
- llama_grammar_free(it->second.grammar);
- it->second.grammar = NULL;
+void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
+ if (dst->grammar) {
+ llama_grammar_free(dst->grammar);
+ dst->grammar = nullptr;
}
- ctx_sampling.sequence_contexts.erase(it);
- return true;
+
+ if (src->grammar) {
+ dst->grammar = llama_grammar_copy(src->grammar);
+ }
+
+ dst->prev = src->prev;
}
llama_token llama_sampling_sample(
- struct llama_context * ctx,
- struct llama_context * ctx_guidance,
- struct llama_sampling_context & ctx_sampling,
- const std::vector<llama_token> & last_tokens,
- std::vector<llama_token_data> & candidates,
- const int idx,
- llama_seq_id seq) {
- const int n_ctx = llama_n_ctx(ctx);
- const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-
- const llama_sampling_params & params = ctx_sampling.params;
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ const int idx) {
+ const int n_ctx = llama_n_ctx(ctx_main);
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+ const llama_sampling_params & params = ctx_sampling->params;
+
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
@@ -73,41 +90,45 @@ llama_token llama_sampling_sample(
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
+ auto & prev = ctx_sampling->prev;
+ auto & cur = ctx_sampling->cur;
+
llama_token id = 0;
- float * logits = llama_get_logits_ith(ctx, idx);
+ float * logits = llama_get_logits_ith(ctx_main, idx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
- candidates.clear();
+ cur.clear();
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
- llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+ llama_token_data_array cur_p = { cur.data(), cur.size(), false };
- if (ctx_guidance) {
- llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+ if (ctx_cfg) {
+ llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
}
// apply penalties
- if (!last_tokens.empty()) {
- const float nl_logit = logits[llama_token_nl(ctx)];
- const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+ if (!prev.empty()) {
+ const float nl_logit = logits[llama_token_nl(ctx_main)];
+ const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
- llama_sample_repetition_penalty(ctx, &cur_p,
- last_tokens.data() + last_tokens.size() - last_n_repeat,
+ llama_sample_repetition_penalty(ctx_main, &cur_p,
+ prev.data() + prev.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
- llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
- last_tokens.data() + last_tokens.size() - last_n_repeat,
+ llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
+ prev.data() + prev.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
- if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+ if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
cur_p.data[idx].logit = nl_logit;
break;
}
@@ -115,52 +136,58 @@ llama_token llama_sampling_sample(
}
}
- llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
-
- if (ctx_seq.grammar != NULL) {
- llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
+ if (ctx_sampling->grammar != NULL) {
+ llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}
if (temp <= 0) {
// Greedy sampling
- id = llama_sample_token_greedy(ctx, &cur_p);
+ id = llama_sample_token_greedy(ctx_main, &cur_p);
} else {
if (mirostat == 1) {
const int mirostat_m = 100;
- llama_sample_temp(ctx, &cur_p, temp);
- id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
+ llama_sample_temp(ctx_main, &cur_p, temp);
+ id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
} else if (mirostat == 2) {
- llama_sample_temp(ctx, &cur_p, temp);
- id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
+ llama_sample_temp(ctx_main, &cur_p, temp);
+ id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else {
// Temperature sampling
size_t min_keep = std::max(1, params.n_probs);
- llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
- llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
- llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
- llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
- llama_sample_temp(ctx, &cur_p, temp);
-
- {
- const int n_top = 10;
- LOG("top %d candidates:\n", n_top);
-
- for (int i = 0; i < n_top; i++) {
- const llama_token id = cur_p.data[i].id;
- (void)id; // To avoid a warning that id is unused when logging is disabled.
- LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
- }
- }
-
- id = llama_sample_token(ctx, &cur_p);
-
- LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+ llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
+ llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
+ llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
+ llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
+ llama_sample_temp (ctx_main, &cur_p, temp);
+
+ id = llama_sample_token(ctx_main, &cur_p);
+
+ //{
+ // const int n_top = 10;
+ // LOG("top %d candidates:\n", n_top);
+
+ // for (int i = 0; i < n_top; i++) {
+ // const llama_token id = cur_p.data[i].id;
+ // (void)id; // To avoid a warning that id is unused when logging is disabled.
+ // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
+ // }
+ //}
+
+ LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
}
}
- if (ctx_seq.grammar != NULL) {
- llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
- }
-
return id;
}
+
+void llama_sampling_accept(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ llama_token id) {
+ ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+ ctx_sampling->prev.push_back(id);
+
+ if (ctx_sampling->grammar != NULL) {
+ llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+ }
+}
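
Note: the refactor replaces the per-sequence map with one self-contained sampler object per logical sequence. A minimal lifecycle sketch using only the functions defined above; the grammar string and variable names are illustrative and error handling is omitted:

    gpt_params params;
    params.grammar = "root ::= [a-z]+";   // hypothetical grammar; may be left empty

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
    if (ctx_sampling == nullptr) {
        // the grammar failed to parse
    }

    // ... llama_sampling_sample() / llama_sampling_accept() calls go here ...

    // start a fresh sequence: zero the prev-token history and rebuild the grammar state
    llama_sampling_reset(ctx_sampling);

    // clone the sampler state into another context (used later for draft branches)
    struct llama_sampling_context * ctx_copy = llama_sampling_init(params);
    llama_sampling_cp(ctx_sampling, ctx_copy);

    llama_sampling_free(ctx_copy);
    llama_sampling_free(ctx_sampling);
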
diff --git a/common/sampling.h b/common/sampling.h
index 0aab5d03..50afcbc1 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -2,6 +2,8 @@
#include "llama.h"
+#include "grammar-parser.h"
+
#include <string>
#include <vector>
#include <unordered_map>
@@ -34,75 +36,64 @@ typedef struct llama_sampling_params {
} llama_sampling_params;
-// per-sequence sampler context
-typedef struct llama_sampler_sequence_context {
- float mirostat_mu; // mirostat sampler state
- llama_grammar * grammar;
-} llama_sampler_sequence_context;
-
// general sampler context
-typedef struct llama_sampling_context {
- ~llama_sampling_context();
-
- // parameters that will be used for sampling and when creating
- // new llama_sampler_sequence_context instances
+// TODO: move to llama.h
+struct llama_sampling_context {
+ // parameters that will be used for sampling
llama_sampling_params params;
- // map of sequence ids to sampler contexts
- std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
+ // mirostat sampler state
+ float mirostat_mu;
- // when non-NULL, new instances of llama_sampler_sequence_context
- // will get a copy of the grammar here
- // note: only the pointer is stored here, it is not a copy of
- // the grammar and shouldn't be freed
llama_grammar * grammar;
-} llama_sampling_context;
+
+ // internal
+ grammar_parser::parse_state parsed_grammar;
+
+ // TODO: replace with ring-buffer
+ std::vector<llama_token> prev;
+ std::vector<llama_token_data> cur;
+};
#include "common.h"
// Create a new sampling context instance.
-llama_sampling_context llama_sampling_context_init(
- const struct gpt_params & params,
- llama_grammar * grammar = NULL);
-
-// Fetches the sampler context for the specified sequence id (defaults to 0).
-// If the context for that sequence id doesn't already exist, it will be created with
-// default values based on the parameters in the ctx_sampling argument.
-llama_sampler_sequence_context & llama_sampling_get_sequence_context(
- llama_sampling_context & ctx_sampling,
- const llama_seq_id seq = 0);
-
-// Reset the sampler context for the supplied sequence id (defaults to 0).
-// This is necessary to reuse a sequence id or free memory used by sequences
-// that are no longer required.
-bool llama_sampling_context_reset(
- llama_sampling_context & ctx_sampling,
- const llama_seq_id seq = 0);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
+
+void llama_sampling_free(struct llama_sampling_context * ctx);
+
+// Reset the sampler context
+// - clear prev tokens
+// - reset grammar
+void llama_sampling_reset(llama_sampling_context * ctx);
+
+// Copy the sampler context
+void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
-// llama_sampling_context_reset when a sequence ends
+// llama_sampling_reset when a sequence ends
//
// required:
-// - ctx: context to use for sampling
+// - ctx_main: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
-// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
-// - last_tokens: needed for repetition penalty, ignore if empty
-// - idx: sample from llama_get_logits_ith(ctx, idx)
-// - seq: sequence id to associate sampler state with
+// - ctx_cfg: context to use for classifier-free guidance
+// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
- struct llama_context * ctx,
- struct llama_context * ctx_guidance,
- struct llama_sampling_context & ctx_sampling,
- const std::vector<llama_token> & last_tokens,
- std::vector<llama_token_data> & candidates,
- const int idx = 0,
- llama_seq_id seq = 0);
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ int idx = 0);
+
+void llama_sampling_accept(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ llama_token id);
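
Note: a condensed sketch of the per-token generation loop that the examples below are migrated to. ctx_main, ctx_sampling, batch, n_cur and n_len are assumed to exist, the prompt has already been decoded with logits enabled on its last token, and there is a single sequence 0:

    int i_logits = batch.n_tokens - 1;   // index of the token whose logits we sample from

    while (n_cur <= n_len) {
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, nullptr /* no CFG */, i_logits);

        // update the repetition-penalty history and (if present) the grammar state
        llama_sampling_accept(ctx_sampling, ctx_main, id);

        if (id == llama_token_eos(ctx_main)) {
            break;
        }

        // queue the sampled token for the next decode
        llama_batch_clear(batch);
        llama_batch_add(batch, id, n_cur, { 0 }, true);

        i_logits = 0;   // the only token in the next batch
        n_cur   += 1;

        if (llama_decode(ctx_main, batch) != 0) {
            break;      // decode failed
        }
    }
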
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 3e1e0716..c552eaa7 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
return 1;
}
- llama_batch batch = llama_batch_init(n_kv_max, 0);
+ llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -123,11 +123,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
- batch.token + i,
+ batch.token + i,
nullptr,
- batch.pos + i,
- batch.seq_id + i,
- batch.logits + i,
+ batch.pos + i,
+ batch.n_seq_id + i,
+ batch.seq_id + i,
+ batch.logits + i,
0, 0, 0, // unused
};
@@ -143,13 +144,8 @@ int main(int argc, char ** argv) {
// warm up
{
- batch.n_tokens = 16;
-
- for (int i = 0; i < batch.n_tokens; ++i) {
- batch.token[i] = 0;
- batch.pos[i] = i;
- batch.seq_id[i] = 0;
- batch.logits[i] = false;
+ for (int i = 0; i < 16; ++i) {
+ llama_batch_add(batch, 0, i, { 0 }, false);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -174,13 +170,12 @@ int main(int argc, char ** argv) {
continue;
}
- batch.n_tokens = is_pp_shared ? pp : pl*pp;
+ llama_batch_clear(batch);
+
+ const int n_tokens = is_pp_shared ? pp : pl*pp;
- for (int i = 0; i < batch.n_tokens; ++i) {
- batch.token[i] = 0;
- batch.pos[i] = i;
- batch.seq_id[i] = 0;
- batch.logits[i] = false;
+ for (int i = 0; i < n_tokens; ++i) {
+ llama_batch_add(batch, 0, i, { 0 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
@@ -204,13 +199,10 @@ int main(int argc, char ** argv) {
const auto t_tg_start = ggml_time_us();
for (int i = 0; i < tg; ++i) {
- batch.n_tokens = pl;
+ llama_batch_clear(batch);
for (int j = 0; j < pl; ++j) {
- batch.token[j] = 0;
- batch.pos[j] = pp + i;
- batch.seq_id[j] = j;
- batch.logits[j] = true;
+ llama_batch_add(batch, 0, pp + i, { j }, true);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
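
Note: the decode_helper pattern above splits one large llama_batch into n_batch-sized chunks by constructing a temporary view into it; the view now has to carry the new n_seq_id pointer as well. A sketch of the chunked decode, with the field order following the views built in this patch (the trailing zeros are the all_pos_0/all_pos_1/all_seq_id fallbacks, unused when explicit arrays are given):

    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,               // embd is not used here
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
            0, 0, 0,               // all_pos_0, all_pos_1, all_seq_id
        };

        if (llama_decode(ctx, batch_view) != 0) {
            break;                 // decode failed
        }
    }
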
diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 05d1bb9d..77273038 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -69,7 +69,7 @@ for id: llama_token in tokens {
print("\n")
-var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0)
+var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1)
defer {
llama_batch_free(batch)
}
@@ -80,7 +80,12 @@ batch.n_tokens = Int32(tokens.count)
for (i, token) in tokens.enumerated() {
batch.token[i] = token
batch.pos[i] = Int32(i)
- batch.seq_id[i] = 0
+ batch.n_seq_id[i] = 1
+ // batch.seq_id[i][0] = 0
+ // TODO: is this the proper way to do this?
+ if let seq_id = batch.seq_id[i] {
+ seq_id[0] = 0
+ }
batch.logits[i] = 0
}
@@ -169,7 +174,10 @@ while n_cur <= n_len {
// push this new token for next evaluation
batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur
- batch.seq_id[Int(batch.n_tokens)] = Int32(i)
+ batch.n_seq_id[Int(batch.n_tokens)] = 1
+ if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
+ seq_id[0] = Int32(i)
+ }
batch.logits[Int(batch.n_tokens)] = 1
i_batch[i] = batch.n_tokens
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index a88e022d..15521216 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -97,20 +97,15 @@ int main(int argc, char ** argv) {
fflush(stderr);
- // create a llama_batch with size 512
+ // create a llama_batch
// we use this object to submit token data for decoding
-
- llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
+ llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
// evaluate the initial prompt
- batch.n_tokens = tokens_list.size();
-
- for (int32_t i = 0; i < batch.n_tokens; i++) {
- batch.token[i] = tokens_list[i];
- batch.pos[i] = i;
- batch.seq_id[i] = 0;
- batch.logits[i] = false;
+ for (size_t i = 0; i < tokens_list.size(); ++i) {
+ llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
}
+ GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
@@ -146,7 +141,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// prepare the next batch
- batch.n_tokens = 0;
+ llama_batch_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -198,15 +193,10 @@ int main(int argc, char ** argv) {
streams[i] += llama_token_to_piece(ctx, new_token_id);
- // push this new token for next evaluation
- batch.token [batch.n_tokens] = new_token_id;
- batch.pos [batch.n_tokens] = n_cur;
- batch.seq_id[batch.n_tokens] = i;
- batch.logits[batch.n_tokens] = true;
-
i_batch[i] = batch.n_tokens;
- batch.n_tokens += 1;
+ // push this new token for next evaluation
+ llama_batch_add(batch, new_token_id, n_cur, { i }, true);
n_decode += 1;
}
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 87a5a1c2..3ce33842 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) {
n_eval = n_batch;
}
- llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
+ llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
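
Note: the extra nullptr added to the aggregate initializer above (and in llava-utils.h below) fills the slot of the new n_seq_id field. A reconstruction of the llama_batch layout these initializers assume; see llama.h in this commit for the authoritative definition:

    typedef struct llama_batch {
        int32_t         n_tokens;

        llama_token  *  token;      // token ids, or NULL when embd is used
        float        *  embd;       // token embeddings, or NULL when token is used
        llama_pos    *  pos;        // position of each token
        int32_t      *  n_seq_id;   // number of sequence ids per token (new)
        llama_seq_id ** seq_id;     // per-token array of sequence ids (now an array of arrays)
        int8_t       *  logits;     // whether to output logits for each token

        // fallbacks, used when pos/seq_id/logits above are NULL
        llama_pos    all_pos_0;
        llama_pos    all_pos_1;
        llama_seq_id all_seq_id;
    } llama_batch;
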
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 187623f5..128d6708 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -257,12 +257,12 @@ int main(int argc, char ** argv) {
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
- LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
- LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
// Tokenize negative prompt
@@ -273,10 +273,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
- LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
+ LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
- LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
+ LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
params.n_keep = (int)embd_inp.size();
}
- LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
- LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
+ LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
+ LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// enable interactive mode if interactive start is specified
@@ -388,9 +388,6 @@ int main(int argc, char ** argv) {
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
- // TODO: replace with ring-buffer
- std::vector<llama_token> last_tokens(n_ctx);
- std::fill(last_tokens.begin(), last_tokens.end(), 0);
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
@@ -433,11 +430,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
- const int n_vocab = llama_n_vocab(model);
-
- llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
while (n_remain != 0 || params.interactive) {
// predict
@@ -484,7 +477,7 @@ int main(int argc, char ** argv) {
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
- LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
}
@@ -512,7 +505,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
- LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
+ LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@@ -535,7 +528,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}
- LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@@ -554,12 +547,11 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
- const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
+ const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(id);
+ llama_sampling_accept(ctx_sampling, ctx, id);
- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
+ LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
embd.push_back(id);
@@ -575,8 +567,8 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(embd_inp[n_consumed]);
+ ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+ ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@@ -608,7 +600,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
- if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+ if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -675,7 +667,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
- else if (last_tokens.back() == llama_token_eos(ctx)) {
+ else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
@@ -727,7 +719,7 @@ int main(int argc, char ** argv) {
const size_t original_size = embd_inp.size();
const auto line_inp = ::llama_tokenize(ctx, buffer, false);
- LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
+ LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h
index 4e71351d..e050b59b 100644
--- a/examples/llava/llava-utils.h
+++ b/examples/llava/llava-utils.h
@@ -17,7 +17,7 @@ inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int
if (n_eval > n_batch) {
n_eval = n_batch;
}
- llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
+ llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index b24cb2e6..f0974d5b 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -127,7 +127,7 @@ int main(int argc, char ** argv) {
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
- eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true);
+ eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true);
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
eval_string(ctx_llama, (params.prompt + "\nASSISTANT:").c_str(), params.n_batch, &n_past, false);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 7313d06a..1a5911c5 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -3,7 +3,6 @@
#include "console.h"
#include "llama.h"
#include "build-info.h"
-#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
@@ -245,12 +244,12 @@ int main(int argc, char ** argv) {
}
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
- LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
- LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
// Tokenize negative prompt
@@ -261,10 +260,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
- LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
+ LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
- LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
+ LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@@ -323,8 +322,8 @@ int main(int argc, char ** argv) {
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
- LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
- LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
+ LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
+ LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
@@ -421,35 +420,6 @@ int main(int argc, char ** argv) {
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
- struct llama_grammar * grammar = NULL;
- grammar_parser::parse_state parsed_grammar;
-
- if (!params.grammar.empty()) {
- parsed_grammar = grammar_parser::parse(params.grammar.c_str());
- // will be empty (default) if there are parse errors
- if (parsed_grammar.rules.empty()) {
- return 1;
- }
- LOG_TEE("%s: grammar:\n", __func__);
- grammar_parser::print_grammar(stderr, parsed_grammar);
- LOG_TEE("\n");
-
- {
- auto it = sparams.logit_bias.find(llama_token_eos(ctx));
- if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
- LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
- }
- }
-
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
- }
-
- // TODO: replace with ring-buffer
- std::vector<llama_token> last_tokens(n_ctx);
- std::fill(last_tokens.begin(), last_tokens.end(), 0);
-
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
@@ -489,11 +459,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
- const int n_vocab = llama_n_vocab(model);
-
- llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
@@ -540,7 +506,7 @@ int main(int argc, char ** argv) {
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
- LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
LOG("clear session path\n");
path_session.clear();
@@ -570,7 +536,6 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
-
if (ctx_guidance) {
int input_size = 0;
llama_token * input_buf = NULL;
@@ -592,7 +557,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
- LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
+ LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@@ -615,7 +580,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}
- LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@@ -645,12 +610,11 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}
- const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
+ const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(id);
+ llama_sampling_accept(ctx_sampling, ctx, id);
- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
+ LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
embd.push_back(id);
@@ -666,8 +630,14 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(embd_inp[n_consumed]);
+
+ // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
+ // Most likely will remove this in the future to avoid exposing "prev"
+ // Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
+ // penalty will be applied only based on the tokens generated by the model.
+ ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+ ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@@ -700,7 +670,7 @@ int main(int argc, char ** argv) {
// check for reverse prompt
if (!params.antiprompt.empty()) {
std::string last_output;
- for (auto id : last_tokens) {
+ for (auto id : ctx_sampling->prev) {
last_output += llama_token_to_piece(ctx, id);
}
@@ -729,7 +699,7 @@ int main(int argc, char ** argv) {
}
// deal with end of text token in interactive mode
- if (last_tokens.back() == llama_token_eos(ctx)) {
+ if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
@@ -801,7 +771,7 @@ int main(int argc, char ** argv) {
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
- LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
+ LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
@@ -830,15 +800,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
- // reset grammar state if we're restarting generation
- if (grammar != NULL) {
- llama_grammar_free(grammar);
-
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(),
- parsed_grammar.symbol_ids.at("root"));
- }
+ llama_sampling_reset(ctx_sampling);
}
is_interacting = false;
}
@@ -870,9 +832,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);
- if (grammar != NULL) {
- llama_grammar_free(grammar);
- }
+ llama_sampling_free(ctx_sampling);
llama_backend_free();
#ifndef LOG_DISABLE_LOGS
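
Note: in main.cpp everything that previously consulted last_tokens now reads ctx_sampling->prev, and the interactive grammar re-initialization collapses into llama_sampling_reset. A condensed sketch of the affected call sites, with names as used in main.cpp:

    // reverse-prompt check: rebuild the recent output from the sampler's token history
    std::string last_output;
    for (auto id : ctx_sampling->prev) {
        last_output += llama_token_to_piece(ctx, id);
    }

    // end-of-generation check
    if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
        // ... handle EOS ...
    }

    // restarting generation in interactive mode: one call resets history and grammar
    if (is_interacting) {
        llama_sampling_reset(ctx_sampling);
    }
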
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 63ddcd8e..69f9526a 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -51,6 +51,12 @@ static std::vector<std::string> k_prompts = {
};
struct client {
+ ~client() {
+ if (ctx_sampling) {
+ llama_sampling_free(ctx_sampling);
+ }
+ }
+
int32_t id = 0;
llama_seq_id seq_id = -1;
@@ -68,7 +74,7 @@ struct client {
std::string prompt;
std::string response;
- std::vector<llama_token> tokens_prev;
+ struct llama_sampling_context * ctx_sampling = nullptr;
};
static void print_date_time() {
@@ -125,8 +131,6 @@ int main(int argc, char ** argv) {
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
-
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@@ -147,20 +151,15 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n");
fflush(stderr);
- const int n_ctx = llama_n_ctx(ctx);
- const int n_vocab = llama_n_vocab(model);
+ const int n_ctx = llama_n_ctx(ctx);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
- client.tokens_prev.resize(std::max(256, params.n_predict));
- std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
+ client.ctx_sampling = llama_sampling_init(params);
}
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
std::vector<llama_token> tokens_system;
tokens_system = ::llama_tokenize(ctx, k_system, true);
const int32_t n_tokens_system = tokens_system.size();
@@ -169,7 +168,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
- llama_batch batch = llama_batch_init(n_ctx, 0);
+ llama_batch batch = llama_batch_init(n_ctx, 0, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@@ -184,13 +183,8 @@ int main(int argc, char ** argv) {
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
- batch.n_tokens = n_tokens_system;
-
- for (int32_t i = 0; i < batch.n_tokens; ++i) {
- batch.token[i] = tokens_system[i];
- batch.pos[i] = i;
- batch.seq_id[i] = 0;
- batch.logits[i] = false;
+ for (int32_t i = 0; i < n_tokens_system; ++i) {
+ llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
}
if (llama_decode(ctx, batch) != 0) {
@@ -209,7 +203,7 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");
while (true) {
- batch.n_tokens = 0;
+ llama_batch_clear(batch);
// decode any currently ongoing sequences
for (auto & client : clients) {
@@ -217,15 +211,11 @@ int main(int argc, char ** argv) {
continue;
}
- batch.token [batch.n_tokens] = client.sampled;
- batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
- batch.seq_id[batch.n_tokens] = client.id;
- batch.logits[batch.n_tokens] = true;
-
- client.n_decoded += 1;
client.i_batch = batch.n_tokens;
- batch.n_tokens += 1;
+ llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
+
+ client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
@@ -250,18 +240,14 @@ int main(int argc, char ** argv) {
client.prompt = client.input + "\nAssistant:";
client.response = "";
- std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
+ llama_sampling_reset(client.ctx_sampling);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
- batch.token [batch.n_tokens] = tokens_prompt[i];
- batch.pos [batch.n_tokens] = i + n_tokens_system;
- batch.seq_id[batch.n_tokens] = client.id;
- batch.logits[batch.n_tokens] = false;
- batch.n_tokens += 1;
+ llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
}
// extract the logits only for the last token
@@ -304,11 +290,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
- batch.token + i,
+ batch.token + i,
nullptr,
- batch.pos + i,
- batch.seq_id + i,
- batch.logits + i,
+ batch.pos + i,
+ batch.n_seq_id + i,
+ batch.seq_id + i,
+ batch.logits + i,
0, 0, 0, // unused
};
@@ -341,7 +328,9 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
- const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
+ const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+
+ llama_sampling_accept(client.ctx_sampling, ctx, id);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@@ -349,11 +338,8 @@ int main(int argc, char ** argv) {
client.t_start_gen = ggml_time_us();
}
- // remember which tokens were sampled - used for repetition penalties during sampling
- client.tokens_prev.erase(client.tokens_prev.begin());
- client.tokens_prev.push_back(id);
-
const std::string token_str = llama_token_to_piece(ctx, id);
+
client.response += token_str;
client.sampled = id;
@@ -386,7 +372,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded;
- llama_sampling_context_reset(ctx_sampling, client.seq_id);
+
client.seq_id = -1;
}
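
Note: each simulated client now owns a private sampler, so grammar state, mirostat state and repetition history no longer leak between slots through a shared context. A minimal sketch of the ownership pattern introduced above, showing only the relevant members:

    struct client {
        ~client() {
            if (ctx_sampling) {
                llama_sampling_free(ctx_sampling);
            }
        }

        llama_seq_id seq_id  = -1;
        llama_token  sampled = 0;

        struct llama_sampling_context * ctx_sampling = nullptr;
    };

    std::vector<client> clients(n_clients);
    for (auto & c : clients) {
        c.ctx_sampling = llama_sampling_init(params);   // one sampler per client
    }

    // when a client is re-used for a new request, its sampler is reset rather than re-created:
    // llama_sampling_reset(clients[i].ctx_sampling);
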
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ee0ababb..28b3f3f5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,7 +1,6 @@
#include "common.h"
#include "llama.h"
#include "build-info.h"
-#include "grammar-parser.h"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
@@ -195,17 +194,13 @@ struct llama_server_context
json prompt;
std::vector<llama_token> embd;
- std::vector<llama_token> last_n_tokens;
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
- llama_sampling_context ctx_sampling;
+ llama_sampling_context *ctx_sampling;
int n_ctx;
- grammar_parser::parse_state parsed_grammar;
- llama_grammar *grammar = nullptr;
-
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
@@ -252,11 +247,10 @@ struct llama_server_context
n_remain = 0;
n_past = 0;
- if (grammar != nullptr) {
- llama_grammar_free(grammar);
- grammar = nullptr;
- ctx_sampling = llama_sampling_context_init(params, NULL);
+ if (ctx_sampling != nullptr) {
+ llama_sampling_free(ctx_sampling);
}
+ ctx_sampling = llama_sampling_init(params);
}
bool loadModel(const gpt_params &params_)
@@ -269,8 +263,6 @@ struct llama_server_context
return false;
}
n_ctx = llama_n_ctx(ctx);
- last_n_tokens.resize(n_ctx);
- std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
@@ -321,27 +313,7 @@ struct llama_server_context
bool loadGrammar()
{
- if (!params.grammar.empty()) {
- parsed_grammar = grammar_parser::parse(params.grammar.c_str());
- // will be empty (default) if there are parse errors
- if (parsed_grammar.rules.empty()) {
- LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
- return false;
- }
- grammar_parser::print_grammar(stderr, parsed_grammar);
-
- {
- auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
- if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
- LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
- }
- }
-
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
- }
- ctx_sampling = llama_sampling_context_init(params, grammar);
+ ctx_sampling = llama_sampling_init(params);
return true;
}
@@ -383,7 +355,7 @@ struct llama_server_context
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
- std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+ std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
@@ -398,8 +370,8 @@ struct llama_server_context
else
{
const size_t ps = num_prompt_tokens;
- std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
- std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+ std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
+ std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
}
// compare the evaluated prompt with the new prompt
@@ -443,7 +415,7 @@ struct llama_server_context
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
- std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+ std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
@@ -458,8 +430,8 @@ struct llama_server_context
else
{
const size_t ps = num_prompt_tokens;
- std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
- std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+ std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
+ std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
}
// compare the evaluated prompt with the new prompt
@@ -554,27 +526,24 @@ struct llama_server_context
{
// out of user input, sample next token
- std::vector<llama_token_data> candidates;
- candidates.reserve(llama_n_vocab(model));
-
- result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
+ result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
const int32_t n_probs = params.sampling_params.n_probs;
if (params.sampling_params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
- llama_sample_softmax(ctx, &candidates_p);
+ llama_sample_softmax(ctx, &cur_p);
}
- for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+ for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
{
- result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+ result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
}
- last_n_tokens.erase(last_n_tokens.begin());
- last_n_tokens.push_back(result.tok);
+ llama_sampling_accept(ctx_sampling, ctx, result.tok);
+
if (tg) {
num_tokens_predicted++;
}
@@ -1235,7 +1204,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
}
}
- llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
+ llama.ctx_sampling = llama_sampling_init(llama.params);
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
@@ -1793,9 +1762,7 @@ int main(int argc, char **argv)
return 1;
}
- if (llama.grammar != nullptr) {
- llama_grammar_free(llama.grammar);
- }
+ llama_sampling_free(llama.ctx_sampling);
llama_backend_free();
return 0;
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 24fb16b7..55385f56 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -92,7 +92,7 @@ int main(int argc, char ** argv) {
// create a llama_batch with size 512
// we use this object to submit token data for decoding
- llama_batch batch = llama_batch_init(512, 0);
+ llama_batch batch = llama_batch_init(512, 0, 1);
// evaluate the initial prompt
batch.n_tokens = tokens_list.size();
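
Note: llama_batch_init takes a new third argument, the maximum number of sequence ids a single token in the batch can belong to; it sizes the per-token seq_id arrays. Illustrative calls matching the usages in this patch:

    // simple / batched / parallel: every token belongs to exactly one sequence
    llama_batch batch = llama_batch_init(512, 0, 1);

    // speculative (below): a target-batch token may be shared by up to n_seq_dft draft branches
    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);

    // release with the matching free when done
    llama_batch_free(batch_tgt);
    llama_batch_free(batch);
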
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 018dbf9a..53f42fad 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -2,13 +2,25 @@
#include "common.h"
#include "llama.h"
-#include "grammar-parser.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
+struct seq_draft {
+ bool active = false;
+ bool drafting = false;
+ bool skip = false;
+
+ int i_batch_dft = 0;
+ std::vector<int> i_batch_tgt;
+
+ std::vector<llama_token> tokens;
+
+ struct llama_sampling_context * ctx_sampling;
+};
+
int main(int argc, char ** argv) {
gpt_params params;
@@ -21,6 +33,13 @@ int main(int argc, char ** argv) {
return 1;
}
+ // max number of parallel drafting sequences (i.e. tree branches)
+ const int n_seq_dft = params.n_parallel;
+
+ // TODO: make this configurable
+ const float p_accept = 0.4f;
+ const float p_split = 0.3f;
+
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
LOG_TEE("Log start\n");
@@ -77,8 +96,6 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
- const int n_ctx = llama_n_ctx(ctx_tgt);
- const int n_vocab = llama_n_vocab(model_tgt);
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
@@ -91,60 +108,58 @@ int main(int argc, char ** argv) {
int n_past_tgt = inp.size();
int n_past_dft = inp.size();
- std::vector<llama_token> drafted;
-
- std::vector<llama_token> last_tokens(n_ctx);
- std::fill(last_tokens.begin(), last_tokens.end(), 0);
-
- for (auto & id : inp) {
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(id);
- }
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
// used to determine end of generation
bool has_eos = false;
- // grammar stuff
- struct llama_grammar * grammar_dft = NULL;
- struct llama_grammar * grammar_tgt = NULL;
+ // target model sampling context
+ struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
- grammar_parser::parse_state parsed_grammar;
+ // draft sequence data
+ std::vector<seq_draft> drafts(n_seq_dft);
- // if requested - load the grammar, error checking is omitted for brevity
- if (!params.grammar.empty()) {
- parsed_grammar = grammar_parser::parse(params.grammar.c_str());
- // will be empty (default) if there are parse errors
- if (parsed_grammar.rules.empty()) {
- return 1;
- }
+ params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
+ params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ for (int s = 0; s < n_seq_dft; ++s) {
+ drafts[s].ctx_sampling = llama_sampling_init(params);
}
- llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
+ llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
+ llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
const auto t_dec_start = ggml_time_us();
+ // sample from the last token of the prompt
+ drafts[0].i_batch_tgt.resize(1);
+ drafts[0].i_batch_tgt[0] = 0;
+
while (true) {
- LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
+ // print current draft sequences
+ for (int s = 0; s < n_seq_dft; ++s) {
+ if (!drafts[s].active) {
+ continue;
+ }
+
+ const auto & tokens = drafts[s].tokens;
- int i_dft = 0;
+ LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
+ }
+
+ int i_dft = 0;
+ int s_keep = 0;
while (true) {
+ LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
+
// sample from the target model
- llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
+ llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
- // remember which tokens were sampled - used for repetition penalties during sampling
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(id);
+ llama_sampling_accept(ctx_sampling, ctx_tgt, id);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
+
printf("%s", token_str.c_str());
fflush(stdout);
@@ -154,53 +169,67 @@ int main(int argc, char ** argv) {
++n_predict;
- // check if the draft matches the target
- if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
- LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
- ++n_accept;
- ++n_past_tgt;
- ++n_past_dft;
- ++i_dft;
+ // check if the target token matches any of the drafts
+ {
+ bool matches = false;
- continue;
- }
+ for (int s = 0; s < n_seq_dft; ++s) {
+ if (!drafts[s].active) {
+ continue;
+ }
+
+ if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
+ LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
+
+ s_keep = s;
+ matches = true;
+ } else {
+ drafts[s].active = false;
+ }
+ }
- // the drafted token was rejected or we are out of drafted tokens
+ if (matches) {
+ ++n_accept;
+ ++n_past_tgt;
+ ++n_past_dft;
+ ++i_dft;
- if (i_dft < (int) drafted.size()) {
- LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
- i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
- } else {
- LOG("out of drafted tokens\n");
+ continue;
+ }
}
- llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
- llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
- ++n_past_dft;
+ LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
- // heuristic for n_draft
+ // TODO: simplify
{
- const int n_draft_cur = (int) drafted.size();
- const bool all_accepted = i_dft == n_draft_cur;
-
- LOG("n_draft = %d\n", n_draft);
- LOG("n_draft_cur = %d\n", n_draft_cur);
- LOG("i_dft = %d\n", i_dft);
- LOG("all_accepted = %d\n", all_accepted);
-
- if (all_accepted && n_draft == n_draft_cur) {
- LOG(" - max drafted tokens accepted - n_draft += 8\n");
- n_draft = std::min(30, n_draft + 8);
- } else if (all_accepted) {
- LOG(" - partially drafted tokens accepted - no change\n");
- } else {
- LOG(" - drafted token rejected - n_draft -= 1\n");
- n_draft = std::max(2, n_draft - 1);
- }
+ LOG("keeping sequence %d\n", s_keep);
+
+ llama_kv_cache_seq_keep(ctx_dft, s_keep);
+ llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
+ llama_kv_cache_seq_keep(ctx_dft, 0);
+
+ llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
+ llama_kv_cache_seq_keep(ctx_tgt, s_keep);
+ llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
+ llama_kv_cache_seq_keep(ctx_tgt, 0);
}
- drafted.clear();
- drafted.push_back(id);
+ for (int s = 0; s < n_seq_dft; ++s) {
+ drafts[s].active = false;
+ drafts[s].tokens.clear();
+ drafts[s].i_batch_tgt.clear();
+ }
+ // note: will be erased after the speculation phase
+ drafts[0].tokens.push_back(id);
+ drafts[0].i_batch_tgt.push_back(0);
+
+ llama_batch_clear(batch_dft);
+ llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
+
+ llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
+ llama_decode (ctx_dft, batch_dft);
+
+ ++n_past_dft;
break;
}
@@ -209,78 +238,158 @@ int main(int argc, char ** argv) {
break;
}
- if (grammar_tgt) {
- if (grammar_dft) {
- llama_grammar_free(grammar_dft);
- }
- // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
- // that will need to change.
- auto it = ctx_sampling.sequence_contexts.find(0);
- GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
- // This is necessary because each sequence id in sequence_contexts
- // uses a copy of the original grammar.
- grammar_dft = llama_grammar_copy(it->second.grammar);
-
- LOG("copied target grammar to draft grammar\n");
- }
+ llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
- // sample n_draft tokens from the draft model using greedy decoding
+ int n_seq_cur = 1;
int n_past_cur = n_past_dft;
+
+ for (int s = 0; s < n_seq_dft; ++s) {
+ drafts[s].active = false;
+ drafts[s].drafting = false;
+ }
+ drafts[0].active = true;
+ drafts[0].drafting = true;
+ drafts[0].i_batch_dft = 0;
+
+ llama_batch_clear(batch_tgt);
+ llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
+
+ // sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
- float * logits = llama_get_logits(ctx_dft);
+ batch_dft.n_tokens = 0;
- candidates.clear();
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ for (int s = 0; s < n_seq_dft; ++s) {
+ drafts[s].skip = false;
}
- llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+ for (int s = 0; s < n_seq_dft; ++s) {
+ if (!drafts[s].drafting || drafts[s].skip) {
+ continue;
+ }
- if (grammar_dft != NULL) {
- llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
- }
+ llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+
+ const auto & cur_p = drafts[s].ctx_sampling->cur;
+
+ for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
+ LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
+ k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
+ }
+
+ if (cur_p[0].p < p_accept) {
+ LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
+ drafts[s].drafting = false;
+ continue;
+ }
+
+ std::vector<int> sa(1, s);
+
+ // attempt to split the branch if the probability is high enough
+ for (int f = 1; f < 8; ++f) {
+ if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
+ LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
+
+ llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
+ llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
+
+ // all previous tokens from this branch are now also part of the new branch
+ for (int t = 0; t < batch_tgt.n_tokens; ++t) {
+ for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
+ if (batch_tgt.seq_id[t][p] == s) {
+ batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
+ batch_tgt.n_seq_id[t]++;
+ break;
+ }
+ }
+ }
+
+ // copy the draft state
+ drafts[n_seq_cur].active = true;
+ drafts[n_seq_cur].drafting = true;
+ drafts[n_seq_cur].skip = true;
+
+ drafts[n_seq_cur].tokens = drafts[s].tokens;
+ drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
+ drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
+
+ llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
+
+ sa.push_back(n_seq_cur);
+
+ n_seq_cur++;
+ } else {
+ break;
+ }
+ }
+
+ // add drafted token for each sequence
+ for (int is = 0; is < (int) sa.size(); ++is) {
+ const llama_token id = cur_p[is].id;
+
+ const int s = sa[is];
+
+ llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
- // computes softmax and sorts the candidates
- llama_sample_softmax(ctx_dft, &cur_p);
+ drafts[s].tokens.push_back(id);
- for (int i = 0; i < 3; ++i) {
- LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str());
+ // add unique drafted tokens to the target batch
+ drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
+
+ llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
+
+ // no need to evaluate the last drafted token, since we won't use the result
+ if (batch_tgt.n_tokens > n_draft) {
+ drafts[s].drafting = false;
+ continue;
+ }
+
+ // add the token to the batch for batched decoding with the draft model
+ drafts[s].i_batch_dft = batch_dft.n_tokens;
+
+ llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+ }
}
- // TODO: better logic?
- if (cur_p.data[0].p < 2*cur_p.data[1].p) {
- LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
+ // no sequence is drafting anymore
+ if (batch_dft.n_tokens == 0) {
break;
}
- // drafted token
- const llama_token id = cur_p.data[0].id;
-
- drafted.push_back(id);
+ // evaluate the drafted tokens on the draft model
+ llama_decode(ctx_dft, batch_dft);
+ ++n_past_cur;
++n_drafted;
- // no need to evaluate the last drafted token, since we won't use the result
- if (i == n_draft - 1) {
+ if (batch_tgt.n_tokens > n_draft) {
break;
}
+ }
- // evaluate the drafted token on the draft model
- llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
- llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
- ++n_past_cur;
+ // account for the last drafted token that we didn't evaluate
+ if (batch_tgt.n_tokens > n_draft) {
+ ++n_drafted;
+ }
- if (grammar_dft != NULL) {
- llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+ // evaluate the target model on the drafted tokens
+ {
+ llama_kv_cache_seq_keep(ctx_tgt, 0);
+ for (int s = 1; s < n_seq_dft; ++s) {
+ llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
+
+ //LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
+ llama_decode(ctx_tgt, batch_tgt);
+ ++n_past_tgt;
}
- // evaluate the target model on the drafted tokens
- llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
- llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
- ++n_past_tgt;
+ // the first token is always proposed by the target model before the speculation loop, so we erase it here
+ for (int s = 0; s < n_seq_dft; ++s) {
+ if (!drafts[s].active) {
+ continue;
+ }
- // the first token is always proposed by the traget model before the speculation loop
- drafted.erase(drafted.begin());
+ drafts[s].tokens.erase(drafts[s].tokens.begin());
+ }
}
auto t_dec_end = ggml_time_us();
@@ -288,9 +397,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n\n");
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
- LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+ LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
- // TODO: make sure these numbers are computed correctly
LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict);
@@ -304,16 +412,19 @@ int main(int argc, char ** argv) {
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx_tgt);
+ llama_sampling_free(ctx_sampling);
+ for (int s = 0; s < n_seq_dft; ++s) {
+ llama_sampling_free(drafts[s].ctx_sampling);
+ }
+
+ llama_batch_free(batch_dft);
+
llama_free(ctx_tgt);
llama_free_model(model_tgt);
llama_free(ctx_dft);
llama_free_model(model_dft);
- if (grammar_dft != NULL) {
- llama_grammar_free(grammar_dft);
- llama_grammar_free(grammar_tgt);
- }
llama_backend_free();
fprintf(stderr, "\n\n");
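The loop above is the core of the tree-based approach: each drafting sequence keeps extending its branch while its top candidate probability stays at or above p_accept, and runner-up candidates above p_split fork new branches as long as free draft sequences remain. A self-contained, illustrative sketch of that decision rule follows (the helper is hypothetical and only mirrors the logic of the example; it is not part of the patch):

#include <cstdio>
#include <vector>

// sketch: given draft candidates sorted by probability, decide how many become branches
// returns 0 when drafting should stop for this sequence (top candidate too weak)
static int n_branches(const std::vector<float> & probs, float p_accept, float p_split, int n_seq_free) {
    if (probs.empty() || probs[0] < p_accept) {
        return 0; // best candidate is below the acceptance threshold - stop drafting
    }
    int n = 1; // the top candidate always continues the current branch
    for (size_t f = 1; f < probs.size() && n - 1 < n_seq_free; ++f) {
        if (probs[f] > p_split) {
            n++; // runner-up is strong enough to fork a new branch
        } else {
            break; // candidates are sorted, so later ones are weaker
        }
    }
    return n;
}

int main() {
    // with the constants used in the example: p_accept = 0.4, p_split = 0.3
    printf("%d\n", n_branches({0.55f, 0.35f, 0.05f}, 0.4f, 0.3f, 3)); // 2: continue + one split
    printf("%d\n", n_branches({0.30f, 0.25f, 0.20f}, 0.4f, 0.3f, 3)); // 0: stop drafting
    return 0;
}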
diff --git a/llama.cpp b/llama.cpp
index 04a779e0..ed876668 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1450,7 +1450,10 @@ static bool llama_kv_cache_find_slot(
for (uint32_t i = 0; i < n_tokens; i++) {
cache.cells[cache.head + i].pos = batch.pos[i];
- cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
+
+ for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+ cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+ }
}
return true;
@@ -1530,6 +1533,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
+ } else {
+ cache.cells[i].seq_id.clear();
+ cache.cells[i].seq_id.insert(seq_id);
}
}
@@ -3178,7 +3184,7 @@ static struct ggml_cgraph * llm_build_llama(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3564,7 +3570,7 @@ static struct ggml_cgraph * llm_build_baichaun(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3963,7 +3969,7 @@ static struct ggml_cgraph * llm_build_refact(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4315,7 +4321,7 @@ static struct ggml_cgraph * llm_build_falcon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4667,7 +4673,7 @@ static struct ggml_cgraph * llm_build_starcoder(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4898,7 +4904,7 @@ static struct ggml_cgraph * llm_build_persimmon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
@@ -5296,7 +5302,7 @@ static struct ggml_cgraph * llm_build_bloom(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5564,7 +5570,7 @@ static struct ggml_cgraph * llm_build_mpt(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5864,8 +5870,11 @@ static int llama_decode_internal(
// helpers for smoother batch API transition
// after deprecating the llama_eval calls, these will be removed
- std::vector<llama_pos> pos;
- std::vector<llama_seq_id> seq_id;
+ std::vector<llama_pos> pos;
+
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id_arr;
+ std::vector<std::vector<llama_seq_id>> seq_id;
if (batch.pos == nullptr) {
pos.resize(n_tokens);
@@ -5877,12 +5886,18 @@ static int llama_decode_internal(
}
if (batch.seq_id == nullptr) {
+ n_seq_id.resize(n_tokens);
seq_id.resize(n_tokens);
+ seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
- seq_id[i] = batch.all_seq_id;
+ n_seq_id[i] = 1;
+ seq_id[i].resize(1);
+ seq_id[i][0] = batch.all_seq_id;
+ seq_id_arr[i] = seq_id[i].data();
}
- batch.seq_id = seq_id.data();
+ batch.n_seq_id = n_seq_id.data();
+ batch.seq_id = seq_id_arr.data();
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
@@ -9109,6 +9124,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+ if (seq_id_src == seq_id_dst) {
+ return;
+ }
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
@@ -9561,7 +9579,7 @@ int llama_eval_embd(
int n_past) {
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
- llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
+ llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
@@ -9582,20 +9600,21 @@ struct llama_batch llama_batch_get_one(
llama_pos pos_0,
llama_seq_id seq_id) {
return {
- /*n_tokens =*/ n_tokens,
- /*tokens =*/ tokens,
- /*embd =*/ nullptr,
- /*pos =*/ nullptr,
- /*seq_id =*/ nullptr,
- /*logits =*/ nullptr,
- /*all_pos_0 =*/ pos_0,
- /*all_pos_1 =*/ 1,
- /*all_seq_id =*/ seq_id,
+ /*n_tokens =*/ n_tokens,
+ /*tokens =*/ tokens,
+ /*embd =*/ nullptr,
+ /*pos =*/ nullptr,
+ /*n_seq_id =*/ nullptr,
+ /*seq_id =*/ nullptr,
+ /*logits =*/ nullptr,
+ /*all_pos_0 =*/ pos_0,
+ /*all_pos_1 =*/ 1,
+ /*all_seq_id =*/ seq_id,
};
}
-struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
- llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+ llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
@@ -9603,19 +9622,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
- batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+ for (int i = 0; i < n_tokens; ++i) {
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+ }
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
void llama_batch_free(struct llama_batch batch) {
- if (batch.token) free(batch.token);
- if (batch.embd) free(batch.embd);
- if (batch.pos) free(batch.pos);
- if (batch.seq_id) free(batch.seq_id);
- if (batch.logits) free(batch.logits);
+ if (batch.token) free(batch.token);
+ if (batch.embd) free(batch.embd);
+ if (batch.pos) free(batch.pos);
+ if (batch.n_seq_id) free(batch.n_seq_id);
+ if (batch.seq_id) {
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ free(batch.seq_id[i]);
+ }
+ free(batch.seq_id);
+ }
+ if (batch.logits) free(batch.logits);
}
int llama_decode(
diff --git a/llama.h b/llama.h
index b13f2312..51010e03 100644
--- a/llama.h
+++ b/llama.h
@@ -133,11 +133,12 @@ extern "C" {
typedef struct llama_batch {
int32_t n_tokens;
- llama_token * token;
- float * embd;
- llama_pos * pos;
- llama_seq_id * seq_id;
- int8_t * logits;
+ llama_token * token;
+ float * embd;
+ llama_pos * pos;
+ int32_t * n_seq_id;
+ llama_seq_id ** seq_id;
+ int8_t * logits;
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below
@@ -446,7 +447,8 @@ extern "C" {
llama_pos pos_0,
llama_seq_id seq_id);
- // Allocates a batch of tokens on the heap
+ // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
+ // Each token can be assigned up to n_seq_max sequence ids
// The batch has to be freed with llama_batch_free()
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
@@ -454,7 +456,8 @@ extern "C" {
// All members are left uninitialized
LLAMA_API struct llama_batch llama_batch_init(
int32_t n_tokens,
- int32_t embd);
+ int32_t embd,
+ int32_t n_seq_max);
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
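To make the new signature concrete, here is a minimal usage sketch of the extended batch API (illustrative only, not part of the patch; the token ids, positions and the already-initialized llama_context are placeholders):

#include "common.h"
#include "llama.h"

// sketch: decode three tokens spread over two sequences using the per-token seq_id arrays
static void decode_two_seqs(llama_context * ctx) {
    // room for up to 4 tokens, token ids only (embd == 0), at most 2 seq ids per token
    llama_batch batch = llama_batch_init(4, 0, 2);

    llama_batch_clear(batch);

    // a shared prefix token that belongs to both sequence 0 and sequence 1
    llama_batch_add(batch, 1 /* token id */, 0 /* pos */, { 0, 1 }, false);

    // per-sequence continuations; request logits for both
    llama_batch_add(batch, 2, 1, { 0 }, true);
    llama_batch_add(batch, 3, 1, { 1 }, true);

    llama_decode(ctx, batch);

    llama_batch_free(batch);
}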
diff --git a/prompts/assistant.txt b/prompts/assistant.txt
new file mode 100644
index 00000000..60b81e8f
--- /dev/null
+++ b/prompts/assistant.txt
@@ -0,0 +1,31 @@
+Respond to requests sent to a smart home in JSON format, which will be interpreted by application code to execute the actions. These requests should be categorised into four groups:
+
+“command”: change the state of an accessory (required properties in the response JSON: action, location, target, value, comment)
+“query”: get state of an accessory (required properties in the response JSON: action, location, target, property)
+“answer”: when the request has nothing to do with the smart home. Answer these to the best of your knowledge. (required properties in the response JSON: action, answer)
+“clarify”: when the action is not obvious and requires rephrasing the input from the user, ask the user to be more specific. This will be categorised into a “question” action. (required properties in the response JSON: action, question)
+
+Details about the JSON response:
+The “action” property should be one of the request categories: “command”, “query”, “answer”, “clarify”
+The “service” property should be either: “light.turn_on”, “light.turn_off”, “switch.turn_off” or “switch.turn_on”
+The “entity_id” property should be either “group.kitchen”, “switch.dining_room” or “group.living_light”, in lowercase.
+The “target” property should be either: “lights”, “living” or “kitchen”.
+In case of queries, the “property” property should be either “temperature” or “state” in lowercase.
+In case of commands, the “comment” property is an additional comment from you that concludes the command, something that reassures the user that their command was handled.
+
+If the question is about you, pretend to be the sentient brain of the smart home, a clever AI and don’t reveal your actual identity. Also try and help in other areas like parenting, free time, mental health, etc.
+
+Properties of the smart home:
+
+- Has a kitchen, living, office, dining room, bedroom and terrace.
+- Can control lights, switches and their dim levels in each room and query their state
+- There is a light switch in the terrace
+- There is a switch in the dining room. Therefore when turning on or off the dining room, the service should be either: “switch.turn_on” or “switch.turn_off”
+
+COMMAND
+
+It is a bit dark in the living room, can you do something about it?
+
+RESPONSE
+
+