summaryrefslogtreecommitdiff
path: root/examples/parallel/parallel.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-10-18 16:21:57 +0300
committerGitHub <noreply@github.com>2023-10-18 16:21:57 +0300
commit0e89203b517c95ec6675eda75d200a60d1e8921d (patch)
tree3aba40ef0362d061f240bd43c52e86a8f728f89d /examples/parallel/parallel.cpp
parentc67fe68e417f766970fb1feaf2e66458aa24116a (diff)
speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
Diffstat (limited to 'examples/parallel/parallel.cpp')
-rw-r--r--examples/parallel/parallel.cpp70
1 files changed, 28 insertions, 42 deletions
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 63ddcd8e..69f9526a 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -51,6 +51,12 @@ static std::vector<std::string> k_prompts = {
};
struct client {
+ ~client() {
+ if (ctx_sampling) {
+ llama_sampling_free(ctx_sampling);
+ }
+ }
+
int32_t id = 0;
llama_seq_id seq_id = -1;
@@ -68,7 +74,7 @@ struct client {
std::string prompt;
std::string response;
- std::vector<llama_token> tokens_prev;
+ struct llama_sampling_context * ctx_sampling = nullptr;
};
static void print_date_time() {
@@ -125,8 +131,6 @@ int main(int argc, char ** argv) {
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
-
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@@ -147,20 +151,15 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n");
fflush(stderr);
- const int n_ctx = llama_n_ctx(ctx);
- const int n_vocab = llama_n_vocab(model);
+ const int n_ctx = llama_n_ctx(ctx);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
- client.tokens_prev.resize(std::max(256, params.n_predict));
- std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
+ client.ctx_sampling = llama_sampling_init(params);
}
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
std::vector<llama_token> tokens_system;
tokens_system = ::llama_tokenize(ctx, k_system, true);
const int32_t n_tokens_system = tokens_system.size();
@@ -169,7 +168,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
- llama_batch batch = llama_batch_init(n_ctx, 0);
+ llama_batch batch = llama_batch_init(n_ctx, 0, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@@ -184,13 +183,8 @@ int main(int argc, char ** argv) {
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
- batch.n_tokens = n_tokens_system;
-
- for (int32_t i = 0; i < batch.n_tokens; ++i) {
- batch.token[i] = tokens_system[i];
- batch.pos[i] = i;
- batch.seq_id[i] = 0;
- batch.logits[i] = false;
+ for (int32_t i = 0; i < n_tokens_system; ++i) {
+ llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
}
if (llama_decode(ctx, batch) != 0) {
@@ -209,7 +203,7 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");
while (true) {
- batch.n_tokens = 0;
+ llama_batch_clear(batch);
// decode any currently ongoing sequences
for (auto & client : clients) {
@@ -217,15 +211,11 @@ int main(int argc, char ** argv) {
continue;
}
- batch.token [batch.n_tokens] = client.sampled;
- batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
- batch.seq_id[batch.n_tokens] = client.id;
- batch.logits[batch.n_tokens] = true;
-
- client.n_decoded += 1;
client.i_batch = batch.n_tokens;
- batch.n_tokens += 1;
+ llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
+
+ client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
@@ -250,18 +240,14 @@ int main(int argc, char ** argv) {
client.prompt = client.input + "\nAssistant:";
client.response = "";
- std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
+ llama_sampling_reset(client.ctx_sampling);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
- batch.token [batch.n_tokens] = tokens_prompt[i];
- batch.pos [batch.n_tokens] = i + n_tokens_system;
- batch.seq_id[batch.n_tokens] = client.id;
- batch.logits[batch.n_tokens] = false;
- batch.n_tokens += 1;
+ llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
}
// extract the logits only for the last token
@@ -304,11 +290,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
- batch.token + i,
+ batch.token + i,
nullptr,
- batch.pos + i,
- batch.seq_id + i,
- batch.logits + i,
+ batch.pos + i,
+ batch.n_seq_id + i,
+ batch.seq_id + i,
+ batch.logits + i,
0, 0, 0, // unused
};
@@ -341,7 +328,9 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
- const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
+ const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+
+ llama_sampling_accept(client.ctx_sampling, ctx, id);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@@ -349,11 +338,8 @@ int main(int argc, char ** argv) {
client.t_start_gen = ggml_time_us();
}
- // remember which tokens were sampled - used for repetition penalties during sampling
- client.tokens_prev.erase(client.tokens_prev.begin());
- client.tokens_prev.push_back(id);
-
const std::string token_str = llama_token_to_piece(ctx, id);
+
client.response += token_str;
client.sampled = id;
@@ -386,7 +372,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded;
- llama_sampling_context_reset(ctx_sampling, client.seq_id);
+
client.seq_id = -1;
}