diff options
author | stduhpf <stephduh@live.fr> | 2024-01-16 12:04:32 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-16 13:04:32 +0200 |
commit | e0324285a569d0583cf2f4a07a2402221ee25f58 (patch) | |
tree | 8372c01c23d94190463f02c5e643142995aa7f84 /common | |
parent | 3e5ca7931c68152e4ec18d126e9c832dd84914c8 (diff) |
speculative : threading options (#4959)
* speculative: expose draft threading
* fix usage format
* accept -td and -tbd args
* speculative: revert default behavior when -td is unspecified
* fix trailing whitespace
Diffstat (limited to 'common')
-rw-r--r-- | common/common.cpp | 22 | ||||
-rw-r--r-- | common/common.h | 2 |
2 files changed, 24 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp index c11006bc..2b0865ff 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -167,6 +167,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { if (params.n_threads_batch <= 0) { params.n_threads_batch = std::thread::hardware_concurrency(); } + } else if (arg == "-td" || arg == "--threads-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_draft = std::stoi(argv[i]); + if (params.n_threads_draft <= 0) { + params.n_threads_draft = std::thread::hardware_concurrency(); + } + } else if (arg == "-tbd" || arg == "--threads-batch-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch_draft = std::stoi(argv[i]); + if (params.n_threads_batch_draft <= 0) { + params.n_threads_batch_draft = std::thread::hardware_concurrency(); + } } else if (arg == "-p" || arg == "--prompt") { if (++i >= argc) { invalid_param = true; @@ -845,6 +863,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); printf(" -tb N, --threads-batch N\n"); printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); + printf(" -td N, --threads-draft N"); + printf(" number of threads to use during generation (default: same as --threads)"); + printf(" -tbd N, --threads-batch-draft N\n"); + printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n"); printf(" -p PROMPT, --prompt PROMPT\n"); printf(" prompt to start generation with (default: empty)\n"); printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); diff --git a/common/common.h b/common/common.h index 09646824..1f43e628 100644 --- a/common/common.h +++ b/common/common.h @@ -46,7 +46,9 @@ struct gpt_params { uint32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); + int32_t n_threads_draft = -1; int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads_batch_draft = -1; int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) |