diff options
author | Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> | 2024-03-05 03:24:00 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-04 20:24:00 +0200 |
commit | 6d341ab6c53cd51f2921d986d0090cc8b049b39a (patch) | |
tree | f212b497e210c8c73fe52369f6bc81297c7b1dab /common/common.cpp | |
parent | 4ffcdce2ff877ebb683cd217ea38faf20faa5ffe (diff) |
speculative : implement stochastic speculative sampling (#5625)
* (WIP) Implement stochastic speculative decoding
* sample from residual distribution on draft accept failure
* fix #5657: force greedy sampling with probs when temp is 0
* remove p_accept parameter
* fix style
* remove unused variables
* add srand() in speculative.cpp
* replace use of rand() with mt19937 sampling
* fixes based on review (@JohannesGaessler)
* fix r random generation
* randomly select next sequence to verify + fix bug in memory freeing
* fix bug in active_seqs sync
* fix uniform int distribution initialization
* remove warnings from comparison between int and size_t
* check grammar in `llama_sample_probability_distribution_impl`
* remove malloc code by utilizing vectors
* add PR link to README
Diffstat (limited to 'common/common.cpp')
-rw-r--r-- | common/common.cpp | 7 |
1 files changed, 0 insertions, 7 deletions
diff --git a/common/common.cpp b/common/common.cpp index dbe7e922..036a9813 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_sequences = std::stoi(argv[i]); - } else if (arg == "--p-accept" || arg == "-pa") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.p_accept = std::stof(argv[i]); } else if (arg == "--p-split" || arg == "-ps") { if (++i >= argc) { invalid_param = true; @@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); |