summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
authorMinsoo Cheong <54794500+mscheong01@users.noreply.github.com>2024-03-05 03:24:00 +0900
committerGitHub <noreply@github.com>2024-03-04 20:24:00 +0200
commit6d341ab6c53cd51f2921d986d0090cc8b049b39a (patch)
treef212b497e210c8c73fe52369f6bc81297c7b1dab /common/common.cpp
parent4ffcdce2ff877ebb683cd217ea38faf20faa5ffe (diff)
speculative : implement stochastic speculative sampling (#5625)
* (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix #5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp7
1 files changed, 0 insertions, 7 deletions
diff --git a/common/common.cpp b/common/common.cpp
index dbe7e922..036a9813 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_sequences = std::stoi(argv[i]);
- } else if (arg == "--p-accept" || arg == "-pa") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params.p_accept = std::stof(argv[i]);
} else if (arg == "--p-split" || arg == "-ps") {
if (++i >= argc) {
invalid_param = true;
@@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
- printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");