summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
authorMarcus Dunn <51931484+MarcusDunn@users.noreply.github.com>2023-10-23 12:40:03 -0700
committerGitHub <noreply@github.com>2023-10-23 22:40:03 +0300
commit5be6c803fa5378f62a1590f3ad8c6b64c7c0c2ce (patch)
tree190868e0431070686d797c3c2d86da857b8ba55f /common
parent6336701c9378c23c85d1c0e464b663ca2bbb8e60 (diff)
llama : remove token functions with `context` args in favor of `model` (#3720)
* added `llama_model_token_*` variants to all the `llama_token_*` functions. * added `LLAMA_API` * formatting Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * removed old `llama_token` functions * changed 3 more functions to take in model - `llama_token_get_text` - `llama_token_get_score` - `llama_token_get_type` * added back docs * fixed main.cpp * changed token functions to use new model variants * changed token functions to use new model variants --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp8
-rw-r--r--common/sampling.cpp4
-rw-r--r--common/train.cpp6
3 files changed, 9 insertions, 9 deletions
diff --git a/common/common.cpp b/common/common.cpp
index bbd1518c..44bb7661 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -880,13 +880,13 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}
if (params.ignore_eos) {
- params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+ params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}
{
LOG("warming up the model with an empty run\n");
- std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+ std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx);
@@ -941,7 +941,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
}
std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
- const llama_token bos_id = llama_token_bos(ctx);
+ const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
std::string piece;
std::string result;
@@ -1186,7 +1186,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
- const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx));
+ const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 6f0af3c4..5258d4e8 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -147,7 +147,7 @@ llama_token llama_sampling_sample(
// apply penalties
if (!prev.empty()) {
- const float nl_logit = logits[llama_token_nl(ctx_main)];
+ const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - penalty_last_n,
@@ -155,7 +155,7 @@ llama_token llama_sampling_sample(
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
- if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
+ if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;
}
diff --git a/common/train.cpp b/common/train.cpp
index 154ca56e..3cce5da2 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -236,8 +236,8 @@ int64_t get_example_targets_batch(
int64_t used_samples = 0;
ggml_set_f32(target_probs, 0.0f);
- llama_token bos = llama_token_bos(lctx);
- llama_token eos = llama_token_eos(lctx);
+ llama_token bos = llama_token_bos(llama_get_model(lctx));
+ llama_token eos = llama_token_eos(llama_get_model(lctx));
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k);
@@ -924,7 +924,7 @@ size_t tokenize_file(
for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max(
max_token_text_size,
- strlen(llama_token_get_text(lctx, token)));
+ strlen(llama_token_get_text(llama_get_model(lctx), token)));
}
// upper bound of context byte length.