summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
authorMarcus Dunn <51931484+MarcusDunn@users.noreply.github.com>2023-10-23 12:40:03 -0700
committerGitHub <noreply@github.com>2023-10-23 22:40:03 +0300
commit5be6c803fa5378f62a1590f3ad8c6b64c7c0c2ce (patch)
tree190868e0431070686d797c3c2d86da857b8ba55f /common/sampling.cpp
parent6336701c9378c23c85d1c0e464b663ca2bbb8e60 (diff)
llama : remove token functions with `context` args in favor of `model` (#3720)
* added `llama_model_token_*` variants to all the `llama_token_*` functions. * added `LLAMA_API` * formatting Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * removed old `llama_token` functions * changed 3 more functions to take in model - `llama_token_get_text` - `llama_token_get_score` - `llama_token_get_type` * added back docs * fixed main.cpp * changed token functions to use new model variants * changed token functions to use new model variants --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp4
1 files changed, 2 insertions, 2 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 6f0af3c4..5258d4e8 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -147,7 +147,7 @@ llama_token llama_sampling_sample(
// apply penalties
if (!prev.empty()) {
- const float nl_logit = logits[llama_token_nl(ctx_main)];
+ const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - penalty_last_n,
@@ -155,7 +155,7 @@ llama_token llama_sampling_sample(
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
- if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
+ if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;
}