summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-11-03 09:41:17 +0200
committerGeorgi Gerganov <ggerganov@gmail.com>2023-11-03 09:41:56 +0200
commit8f961abdc4e134c83bf8c2ad618ab256b4cae0f9 (patch)
tree8dd4776cfc627709436fb25fabe56385f9fab35d /common
parent05816027d649f977468fc804cdb54e99eac246d1 (diff)
speculative : change default p_accept to 0.5 + CLI args (#3919)
ggml-ci
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp14
-rw-r--r--common/common.h8
2 files changed, 20 insertions, 2 deletions
diff --git a/common/common.cpp b/common/common.cpp
index e938dee1..20cc4a08 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -403,6 +403,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_sequences = std::stoi(argv[i]);
+ } else if (arg == "--p-accept" || arg == "-pa") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.p_accept = std::stof(argv[i]);
+ } else if (arg == "--p-split" || arg == "-ps") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.p_split = std::stof(argv[i]);
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_param = true;
@@ -778,6 +790,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
+ printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
+ printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
diff --git a/common/common.h b/common/common.h
index 9ad62563..dd6b002e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -44,6 +44,7 @@ int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = -1; // RNG seed
+
int32_t n_threads = get_num_physical_cores();
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_predict = -1; // new tokens to predict
@@ -54,6 +55,8 @@ struct gpt_params {
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
+ float p_accept = 0.5f; // speculative decoding accept probability
+ float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
@@ -66,7 +69,8 @@ struct gpt_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
- int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
+ int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
+ // pinging @cebtenzzre
// // sampling parameters
struct llama_sampling_params sparams;
@@ -90,7 +94,7 @@ struct gpt_params {
int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
// (which is more convenient to use for plotting)
//
- bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+ bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS