summaryrefslogtreecommitdiff
path: root/main.cpp
diff options
context:
space:
mode:
authorLuciano <lucianostrika44@gmail.com>2023-03-24 08:05:13 -0700
committerGitHub <noreply@github.com>2023-03-24 17:05:13 +0200
commit8d4a855c241ecb0f3ddc03447fe56002ebf27a37 (patch)
tree4de329fb2849fb6128d05237850b8ceb7519bf36 /main.cpp
parentb6b268d4415fd3b3e53f22b6619b724d4928f713 (diff)
Add embedding mode with arg flag. Currently working (#282)
* working but ugly * add arg flag, not working on embedding mode * typo * Working! Thanks to @nullhook * make params argument instead of hardcoded boolean. remove useless time check * start doing the instructions but not finished. This probably doesnt compile * Embeddings extraction support --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'main.cpp')
-rw-r--r--main.cpp23
1 files changed, 23 insertions, 0 deletions
diff --git a/main.cpp b/main.cpp
index 5ba6d5a7..46a80ff8 100644
--- a/main.cpp
+++ b/main.cpp
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
+ lparams.embedding = params.embedding;
ctx = llama_init_from_file(params.model.c_str(), lparams);
@@ -292,6 +293,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
+
int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);
+ if (params.embedding){
+ embd = embd_inp;
+
+ if (embd.size() > 0) {
+ if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval\n", __func__);
+ return 1;
+ }
+ }
+
+ const auto embeddings = llama_get_embeddings(ctx);
+
+ // TODO: print / use the embeddings
+
+ if (params.use_color) {
+ printf(ANSI_COLOR_RESET);
+ }
+
+ return 0;
+ }
+
while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {