summary refs log tree commit diff
path: root/examples/simple/simple.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/simple/simple.cpp')
-rw-r--r-- examples/simple/simple.cpp 136
1 files changed, 97 insertions, 39 deletions
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 440d22ec..1616a4a7 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -26,12 +26,18 @@ int main(int argc, char ** argv) {
params.prompt = "Hello my name is";
}
+ // total length of the sequence including the prompt
+ const int n_len = 32;
+
// init LLM
llama_backend_init(params.numa);
llama_context_params ctx_params = llama_context_default_params();
+ ctx_params.seed = 1234;
+ ctx_params.n_ctx = 2048;
+
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
if (model == NULL) {
@@ -41,20 +47,31 @@ int main(int argc, char ** argv) {
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+ if (ctx == NULL) {
+ fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+ return 1;
+ }
+
// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
- const int max_context_size = llama_n_ctx(ctx);
- const int max_tokens_list_size = max_context_size - 4;
+ const int n_ctx = llama_n_ctx(ctx);
+ const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+ LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
- if ((int) tokens_list.size() > max_tokens_list_size) {
- fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
+ // make sure the KV cache is big enough to hold all the prompt and generated tokens
+ if (n_kv_req > n_ctx) {
+ LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+ LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
return 1;
}
- fprintf(stderr, "\n\n");
+ // print the prompt token-by-token
+
+ fprintf(stderr, "\n");
for (auto id : tokens_list) {
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
@@ -62,63 +79,104 @@ int main(int argc, char ** argv) {
fflush(stderr);
- // main loop
+ // create a llama_batch with size 512
+ // we use this object to submit token data for decoding
- // The LLM keeps a contextual cache memory of previous token evaluation.
- // Usually, once this cache is full, it is required to recompute a compressed context based on previous
- // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
- // example, we will just stop the loop once this cache is full or once an end of stream is detected.
+ llama_batch batch = llama_batch_init(512, 0);
- const int n_gen = std::min(32, max_context_size);
+ // evaluate the initial prompt
+ batch.n_tokens = tokens_list.size();
- while (llama_get_kv_cache_token_count(ctx) < n_gen) {
- // evaluate the transformer
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
+ batch.token[i] = tokens_list[i];
+ batch.pos[i] = i;
+ batch.seq_id[i] = 0;
+ batch.logits[i] = false;
+ }
- if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
- fprintf(stderr, "%s : failed to eval\n", __func__);
- return 1;
- }
+ // llama_decode will output logits only for the last token of the prompt
+ batch.logits[batch.n_tokens - 1] = true;
+
+ if (llama_decode(ctx, batch, params.n_threads) != 0) {
+ LOG_TEE("%s: llama_decode() failed\n", __func__);
+ return 1;
+ }
+
+ // main loop
- tokens_list.clear();
+ int n_cur = batch.n_tokens;
+ int n_decode = 0;
+ const auto t_main_start = ggml_time_us();
+
+ while (n_cur <= n_len) {
// sample the next token
+ {
+ auto n_vocab = llama_n_vocab(ctx);
+ auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
- llama_token new_token_id = 0;
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
- auto logits = llama_get_logits(ctx);
- auto n_vocab = llama_n_vocab(ctx);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+ }
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
+ // sample the most likely token
+ const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+ // is it an end of stream?
+ if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+ LOG_TEE("\n");
+
+ break;
+ }
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+ fflush(stdout);
- new_token_id = llama_sample_token_greedy(ctx , &candidates_p);
+ // prepare the next batch
+ batch.n_tokens = 0;
- // is it an end of stream ?
- if (new_token_id == llama_token_eos(ctx)) {
- fprintf(stderr, " [end of text]\n");
- break;
+ // push this new token for next evaluation
+ batch.token [batch.n_tokens] = new_token_id;
+ batch.pos [batch.n_tokens] = n_cur;
+ batch.seq_id[batch.n_tokens] = 0;
+ batch.logits[batch.n_tokens] = true;
+
+ batch.n_tokens += 1;
+
+ n_decode += 1;
}
- // print the new token :
- printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
- fflush(stdout);
+ n_cur += 1;
- // push this new token for next evaluation
- tokens_list.push_back(new_token_id);
+ // evaluate the current batch with the transformer model
+ if (llama_decode(ctx, batch, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+ return 1;
+ }
}
+ LOG_TEE("\n");
+
+ const auto t_main_end = ggml_time_us();
+
+ LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+ __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+ llama_print_timings(ctx);
+
+ fprintf(stderr, "\n");
+
+ llama_batch_free(batch);
+
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
- fprintf(stderr, "\n\n");
-
return 0;
}