diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-06-04 21:23:39 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-04 21:23:39 +0300 |
commit | 1442677f92e45a475be7b4d056e3633d1d6f813b (patch) | |
tree | d9dbb111ccaedc44cba527dbddd90bedd1e04ea8 /examples/retrieval/retrieval.cpp | |
parent | 554c247caffed64465f372661f2826640cb10430 (diff) |
common : refactor cli arg parsing (#7675)
* common : gpt_params_parse do not print usage
* common : rework usage print (wip)
* common : valign
* common : rework print_usage
* infill : remove cfg support
* common : reorder args
* server : deduplicate parameters
ggml-ci
* common : add missing header
ggml-ci
* common : remote --random-prompt usages
ggml-ci
* examples : migrate to gpt_params
ggml-ci
* batched-bench : migrate to gpt_params
* retrieval : migrate to gpt_params
* common : change defaults for escape and n_ctx
* common : remove chatml and instruct params
ggml-ci
* common : passkey use gpt_params
Diffstat (limited to 'examples/retrieval/retrieval.cpp')
-rw-r--r-- | examples/retrieval/retrieval.cpp | 90 |
1 files changed, 16 insertions, 74 deletions
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 4e753070..55b7b2f7 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -4,72 +4,12 @@ #include <algorithm> #include <fstream> -struct retrieval_params { - std::vector<std::string> context_files; // context files to embed - int32_t chunk_size = 64; // chunk size for context embedding - std::string chunk_separator = "\n"; // chunk separator for context embedding -}; +static void print_usage(int argc, char ** argv, const gpt_params & params) { + gpt_params_print_usage(argc, argv, params); -static void retrieval_params_print_usage(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & params) { - gpt_params_print_usage(argc, argv, gpt_params); - printf("retrieval options:\n"); - printf(" --context-file FNAME file containing context to embed.\n"); - printf(" specify multiple files by providing --context-file option multiple times.\n"); - printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size); - printf(" --chunk-separator STRING\n"); - printf(" string to separate chunks (default: \"\\n\")\n"); - printf("\n"); -} - -static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) { - int i = 1; - std::string arg; - while (i < argc) { - arg = argv[i]; - bool invalid_gpt_param = false; - if(gpt_params_find_arg(argc, argv, argv[i], gpt_params, i, invalid_gpt_param)) { - if (invalid_gpt_param) { - fprintf(stderr, "error: invalid argument: %s\n", arg.c_str()); - retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); - exit(1); - } - // option was parsed by gpt_params_find_arg - } else if (arg == "--context-file") { - if (++i >= argc) { - fprintf(stderr, "error: missing argument for --context-file\n"); - retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); - exit(1); - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); - exit(1); - } - // store the external file name in params - retrieval_params.context_files.push_back(argv[i]); - } else if (arg == "--chunk-size") { - if (++i >= argc) { - fprintf(stderr, "error: missing argument for --chunk-size\n"); - retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); - exit(1); - } - retrieval_params.chunk_size = std::stoi(argv[i]); - } else if (arg == "--chunk-separator") { - if (++i >= argc) { - fprintf(stderr, "error: missing argument for --chunk-separator\n"); - retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); - exit(1); - } - retrieval_params.chunk_separator = argv[i]; - } else { - // unknown argument - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params); - exit(1); - } - i++; - } + LOG_TEE("\nexample usage:\n"); + LOG_TEE("\n %s --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .\n", argv[0]); + LOG_TEE("\n"); } struct chunk { @@ -171,33 +111,35 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu int main(int argc, char ** argv) { gpt_params params; - retrieval_params retrieval_params; - retrieval_params_parse(argc, argv, params, retrieval_params); + if (!gpt_params_parse(argc, argv, params)) { + print_usage(argc, argv, params); + return 1; + } // For BERT models, batch size must be equal to ubatch size params.n_ubatch = params.n_batch; + params.embedding = true; - if (retrieval_params.chunk_size <= 0) { + if (params.chunk_size <= 0) { fprintf(stderr, "chunk_size must be positive\n"); return 1; } - if (retrieval_params.context_files.empty()) { + if (params.context_files.empty()) { fprintf(stderr, "context_files must be specified\n"); return 1; } - params.embedding = true; print_build_info(); printf("processing files:\n"); - for (auto & context_file : retrieval_params.context_files) { + for (auto & context_file : params.context_files) { printf("%s\n", context_file.c_str()); } std::vector<chunk> chunks; - for (auto & context_file : retrieval_params.context_files) { - std::vector<chunk> file_chunk = chunk_file(context_file, retrieval_params.chunk_size, retrieval_params.chunk_separator); + for (auto & context_file : params.context_files) { + std::vector<chunk> file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator); chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end()); } printf("Number of chunks: %ld\n", chunks.size()); @@ -242,7 +184,7 @@ int main(int argc, char ** argv) { return 1; } // add eos if not present - if (inp.empty() || inp.back() != llama_token_eos(model)) { + if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { inp.push_back(llama_token_eos(model)); } chunk.tokens = inp; |