Diffstat (limited to 'examples/passkey')
-rw-r--r--  examples/passkey/README.md   |  2 +-
-rw-r--r--  examples/passkey/passkey.cpp | 66 ++++++++++++-----------------------
2 files changed, 24 insertions(+), 44 deletions(-)
diff --git a/examples/passkey/README.md b/examples/passkey/README.md
index 4a22bb55..9e7a119b 100644
--- a/examples/passkey/README.md
+++ b/examples/passkey/README.md
@@ -8,5 +8,5 @@ See the following PRs for more info:
 ### Usage
 
 ```bash
-make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250
+make -j && ./passkey -m ./models/llama-7b-v2/ggml-model-f16.gguf --junk 250
 ```
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index f2ef9ca1..d03215cd 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -6,46 +6,32 @@
 #include <string>
 #include <vector>
 
-int main(int argc, char ** argv) {
-    gpt_params params;
-
-    if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]);
-        return 1 ;
-    }
-
-    int seed = -1;
+static void print_usage(int argc, char ** argv, const gpt_params & params) {
+    gpt_params_print_usage(argc, argv, params);
 
-    int n_junk = 250; // number of times to repeat the junk text
-    int n_keep = 32;  // number of tokens in the prompt prefix
-    int n_grp  = 1;   // if more than 1 - perform LongLM SelfExtend
-    int i_pos  = -1;  // position of the passkey in the junk text
-
-    if (argc >= 2) {
-        params.model = argv[1];
-    }
-
-    if (argc >= 3) {
-        n_junk = std::stoi(argv[2]);
-    }
+    LOG_TEE("\nexample usage:\n");
+    LOG_TEE("\n    %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]);
+    LOG_TEE("\n");
+}
 
-    if (argc >= 4) {
-        n_grp = std::stoi(argv[3]);
-    }
+int main(int argc, char ** argv) {
+    gpt_params params;
 
-    if (argc >= 5) {
-        i_pos = std::stoi(argv[4]);
-    }
+    params.n_junk = 250;
+    params.n_keep = 32;
+    params.i_pos  = -1;
 
-    if (argc >= 6) {
-        seed = std::stoi(argv[5]);
+    if (!gpt_params_parse(argc, argv, params)) {
+        print_usage(argc, argv, params);
+        return 1;
     }
 
-    if (seed == -1) {
-        seed = time(NULL);
-    }
+    srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);
 
-    srand(seed);
+    int n_junk = params.n_junk;
+    int n_keep = params.n_keep;
+    int n_grp  = params.grp_attn_n;
+    int i_pos  = params.i_pos;
 
     if (i_pos == -1) {
         i_pos = rand() % n_junk;
@@ -76,9 +62,7 @@ int main(int argc, char ** argv) {
 
     // initialize the model
 
-    llama_model_params model_params = llama_model_default_params();
-
-    model_params.n_gpu_layers = 99; // offload all layers to the GPU
+    llama_model_params model_params = llama_model_params_from_gpt_params(params);
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
@@ -89,13 +73,9 @@ int main(int argc, char ** argv) {
 
     // initialize the context
 
-    llama_context_params ctx_params = llama_context_default_params();
+    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
 
-    ctx_params.seed    = seed;
-    ctx_params.n_ctx   = llama_n_ctx_train(model)*n_grp + n_keep;
-    ctx_params.n_batch = 512;
-    ctx_params.n_threads = params.n_threads;
-    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
 
     GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
 
@@ -135,7 +115,7 @@ int main(int argc, char ** argv) {
     LOG_TEE("prompt tokens: %d\n", n_tokens_all);
     //LOG_TEE("prompt: %s\n", params.prompt.c_str());
 
-    llama_batch batch = llama_batch_init(512, 0, 1);
+    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
 
     int n_past = 0;
 
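Note on the context sizing kept by this patch: the context is set to `llama_n_ctx_train(model)*n_grp + n_keep`, so group attention (LongLM SelfExtend) multiplies the model's training context. The following is a minimal sketch, not part of the commit, that works through that arithmetic assuming a model with a 4096-token training context and the flag values from the usage string above.

```cpp
// Sketch only: illustrates the context-size formula used by passkey.cpp,
// assuming llama_n_ctx_train(model) == 4096 (e.g. a LLaMA-2 7B model).
#include <cstdio>

int main() {
    const int n_ctx_train = 4096; // assumed training context of the model
    const int n_grp       = 2;    // --grp-attn-n 2 -> SelfExtend, 2x extension
    const int n_keep      = 32;   // --keep 32 prompt-prefix tokens

    // same formula as the diff: n_ctx = llama_n_ctx_train(model)*n_grp + n_keep
    const int n_ctx = n_ctx_train * n_grp + n_keep;
    std::printf("effective n_ctx = %d\n", n_ctx); // 4096*2 + 32 = 8224

    // the GGML_ASSERT in the diff requires n_batch to divide evenly by n_grp;
    // the default n_batch of 512 satisfies this for small power-of-two n_grp
    const int n_batch = 512;
    std::printf("n_batch %% n_grp = %d\n", n_batch % n_grp); // 0, assertion holds

    return 0;
}
```

After this change the same run is expressed through common CLI flags, as in the new print_usage string: `./passkey -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]`.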