Diffstat (limited to 'examples/simple/simple.cpp')
-rw-r--r--  examples/simple/simple.cpp | 50
1 file changed, 21 insertions(+), 29 deletions(-)
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index b0f8e0fd..69a92cf7 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -6,28 +6,27 @@
 #include <string>
 #include <vector>
 
-int main(int argc, char ** argv) {
-    gpt_params params;
+static void print_usage(int argc, char ** argv, const gpt_params & params) {
+    gpt_params_print_usage(argc, argv, params);
 
-    if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
-        return 1 ;
-    }
+    LOG_TEE("\nexample usage:\n");
+    LOG_TEE("\n    %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
+    LOG_TEE("\n");
+}
 
-    if (argc >= 2) {
-        params.model = argv[1];
-    }
+int main(int argc, char ** argv) {
+    gpt_params params;
 
-    if (argc >= 3) {
-        params.prompt = argv[2];
-    }
+    params.prompt = "Hello my name is";
+    params.n_predict = 32;
 
-    if (params.prompt.empty()) {
-        params.prompt = "Hello my name is";
+    if (!gpt_params_parse(argc, argv, params)) {
+        print_usage(argc, argv, params);
+        return 1;
     }
 
     // total length of the sequence including the prompt
-    const int n_len = 32;
+    const int n_predict = params.n_predict;
 
     // init LLM
 
@@ -36,9 +35,7 @@ int main(int argc, char ** argv) {
 
     // initialize the model
 
-    llama_model_params model_params = llama_model_default_params();
-
-    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+    llama_model_params model_params = llama_model_params_from_gpt_params(params);
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
@@ -49,12 +46,7 @@ int main(int argc, char ** argv) {
 
     // initialize the context
 
-    llama_context_params ctx_params = llama_context_default_params();
-
-    ctx_params.seed = 1234;
-    ctx_params.n_ctx = 2048;
-    ctx_params.n_threads = params.n_threads;
-    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
@@ -69,14 +61,14 @@ int main(int argc, char ** argv) {
     tokens_list = ::llama_tokenize(ctx, params.prompt, true);
 
     const int n_ctx    = llama_n_ctx(ctx);
-    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+    const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
+    LOG_TEE("\n%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
         LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
-        LOG_TEE("%s:        either reduce n_len or increase n_ctx\n", __func__);
+        LOG_TEE("%s:        either reduce n_predict or increase n_ctx\n", __func__);
         return 1;
     }
 
@@ -115,7 +107,7 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
-    while (n_cur <= n_len) {
+    while (n_cur <= n_predict) {
         // sample the next token
         {
             auto n_vocab = llama_n_vocab(model);
@@ -134,7 +126,7 @@ int main(int argc, char ** argv) {
             const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
             // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
                 LOG_TEE("\n");
 
                 break;
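
The change drops the ad-hoc argv handling and hard-coded model/context settings in favor of the shared gpt_params helpers used by the other examples. Below is a minimal sketch of the resulting initialization flow; it only uses calls visible in the diff (gpt_params_parse, gpt_params_print_usage, llama_model_params_from_gpt_params, llama_context_params_from_gpt_params) and assumes they are provided by the common "common.h" header of the repository at this revision.

// Minimal sketch of the post-change initialization flow (not the full example).
#include "common.h"   // assumed: gpt_params and the *_from_gpt_params helpers
#include "llama.h"

int main(int argc, char ** argv) {
    gpt_params params;

    // defaults, overridable on the command line (-p, -n, -m, ...)
    params.prompt    = "Hello my name is";
    params.n_predict = 32;

    if (!gpt_params_parse(argc, argv, params)) {
        gpt_params_print_usage(argc, argv, params);
        return 1;
    }

    // model and context settings now come from the parsed params instead of
    // hard-coded values (seed, n_ctx, thread counts, GPU offload, ...)
    llama_model_params model_params = llama_model_params_from_gpt_params(params);
    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
    llama_context * ctx = llama_new_context_with_model(model, ctx_params);

    // ... backend init, tokenization, decoding and sampling as in the rest of simple.cpp ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}

With the shared parser in place, the example is invoked like the other llama.cpp examples, e.g. ./simple -m model.gguf -p "Hello my name is" -n 32, exactly as advertised by the new print_usage message.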