From e0324285a569d0583cf2f4a07a2402221ee25f58 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Tue, 16 Jan 2024 12:04:32 +0100 Subject: speculative : threading options (#4959) * speculative: expose draft threading * fix usage format * accept -td and -tbd args * speculative: revert default behavior when -td is unspecified * fix trailing whitespace --- common/common.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'common/common.cpp') diff --git a/common/common.cpp b/common/common.cpp index c11006bc..2b0865ff 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -167,6 +167,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { if (params.n_threads_batch <= 0) { params.n_threads_batch = std::thread::hardware_concurrency(); } + } else if (arg == "-td" || arg == "--threads-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_draft = std::stoi(argv[i]); + if (params.n_threads_draft <= 0) { + params.n_threads_draft = std::thread::hardware_concurrency(); + } + } else if (arg == "-tbd" || arg == "--threads-batch-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch_draft = std::stoi(argv[i]); + if (params.n_threads_batch_draft <= 0) { + params.n_threads_batch_draft = std::thread::hardware_concurrency(); + } } else if (arg == "-p" || arg == "--prompt") { if (++i >= argc) { invalid_param = true; @@ -845,6 +863,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); printf(" -tb N, --threads-batch N\n"); printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); + printf(" -td N, --threads-draft N"); + printf(" number of threads to use during generation (default: same as --threads)"); + printf(" -tbd N, --threads-batch-draft N\n"); + printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n"); printf(" -p PROMPT, --prompt PROMPT\n"); printf(" prompt to start generation with (default: empty)\n"); printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); -- cgit v1.2.3