diff options
author | Olivier Chafik <ochafik@users.noreply.github.com> | 2024-04-11 19:47:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-11 19:47:34 +0100 |
commit | cbaadc92942c50aab599a9e4c163afc1f44f7c26 (patch) | |
tree | 0a4b962430740a81a6b1789f1edd9ee50074dde3 /llama.cpp | |
parent | 1bbdaf6ecda6f0a360dfb307b256fcb6838c560b (diff) |
grammars: 1.5x faster inference w/ complex grammars (vector reserves / reuses) (#6609)
* grammars: reserve rejects & next candidates
* grammars: reuse new_stacks
* grammars: fix missing sig change in llama.h
* grammars: fix test (api changed)
* grammars: update gbnf-validator.cpp
* grammars: simpler syntax (no swap)
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 16 |
1 file changed, 10 insertions(+), 6 deletions(-)
@@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
 // be positioned at a character range (see `llama_grammar_advance_stack`), and
 // produces the N possible stacks if the given char is accepted at those
 // positions
-std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+void llama_grammar_accept(
         const std::vector<std::vector<llama_grammar_element>>   & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t                                                  chr) {
+        const uint32_t                                                  chr,
+        std::vector<std::vector<const llama_grammar_element *>>       & new_stacks) {
 
-    std::vector<std::vector<const llama_grammar_element *>> new_stacks;
+    new_stacks.clear();
 
     for (const auto & stack : stacks) {
         if (stack.empty()) {
@@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
             llama_grammar_advance_stack(rules, new_stack, new_stacks);
         }
     }
-
-    return new_stacks;
 }
 
 static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
@@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
         const std::vector<llama_grammar_candidate>            & candidates) {
 
     std::vector<llama_grammar_candidate> rejects;
+    rejects.reserve(candidates.size());
 
     if (stack.empty()) {
         for (const auto & tok : candidates) {
@@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
     const llama_grammar_element * stack_pos = stack.back();
 
     std::vector<llama_grammar_candidate> next_candidates;
+    next_candidates.reserve(candidates.size());
+
     for (const auto & tok : candidates) {
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
@@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     // Note terminating 0 in decoded string
     const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
+
+    std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
+        grammar->stacks = tmp_new_stacks;
     }
     grammar->partial_utf8 = decoded.second;
     GGML_ASSERT(!grammar->stacks.empty());