diff options
Diffstat (limited to 'examples/retrieval/retrieval.cpp')
-rw-r--r-- | examples/retrieval/retrieval.cpp | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 55b7b2f7..eb89d16d 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -73,9 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) { - for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); +static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -160,6 +161,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + return 1; + } + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); |