summaryrefslogtreecommitdiff
path: root/examples/embd-input
diff options
context:
space:
mode:
Diffstat (limited to 'examples/embd-input')
-rw-r--r--examples/embd-input/embd-input-lib.cpp11
1 files changed, 6 insertions, 5 deletions
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index c995eef3..9bd4d347 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -80,7 +80,8 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) {
n_eval = n_batch;
}
- if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
+ llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
+ if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
@@ -101,7 +102,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
- if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
@@ -183,11 +184,11 @@ llama_token sampling_id(struct MyModel* mymodel) {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
@@ -195,7 +196,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}