summary | refs | log | tree | commit | diff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r-- common/common.cpp 22
-rw-r--r-- common/common.h 2
2 files changed, 24 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index c11006bc..2b0865ff 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -167,6 +167,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency();
}
+ } else if (arg == "-td" || arg == "--threads-draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads_draft = std::stoi(argv[i]);
+ if (params.n_threads_draft <= 0) {
+ params.n_threads_draft = std::thread::hardware_concurrency();
+ }
+ } else if (arg == "-tbd" || arg == "--threads-batch-draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads_batch_draft = std::stoi(argv[i]);
+ if (params.n_threads_batch_draft <= 0) {
+ params.n_threads_batch_draft = std::thread::hardware_concurrency();
+ }
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_param = true;
@@ -845,6 +863,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
+ printf(" -td N, --threads-draft N\n"); // newline-terminated like the other option header lines
+ printf(" number of threads to use during generation (default: same as --threads)\n");
+ printf(" -tbd N, --threads-batch-draft N\n");
+ printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n");
printf(" -p PROMPT, --prompt PROMPT\n");
printf(" prompt to start generation with (default: empty)\n");
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
diff --git a/common/common.h b/common/common.h
index 09646824..1f43e628 100644
--- a/common/common.h
+++ b/common/common.h
@@ -46,7 +46,9 @@ struct gpt_params {
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
+ int32_t n_threads_draft = -1; // set by -td / --threads-draft; usage text states the default is "same as --threads" (NOTE(review): the -1 fallback is applied outside this hunk - confirm)
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
+ int32_t n_threads_batch_draft = -1; // set by -tbd / --threads-batch-draft; usage text states the default is "same as --threads-draft" (NOTE(review): fallback not visible here - confirm)
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)