summaryrefslogtreecommitdiff
path: root/examples/embd-input/embd-input-test.cpp
diff options
context:
space:
mode:
authorningshanwutuobang <ningshanwutuobang@gmail.com>2023-06-28 23:53:37 +0800
committerGitHub <noreply@github.com>2023-06-28 18:53:37 +0300
commitcfa0750bc9dbc2d957a91b8ed09ab0035d8f3d4e (patch)
treec8d6d6e6548d4f03899704f64bce6939e471e4e6 /examples/embd-input/embd-input-test.cpp
parent9d23589d638dc74577d5ff880e6d4248b795f12e (diff)
llama : support input embeddings directly (#1910)
* add interface for float input * fixed inpL shape and type * add examples of input floats * add test example for embd input * fixed sampling * add free for context * fixed add end condition for generating * add examples for llava.py * add READMD for llava.py * add READMD for llava.py * add example of PandaGPT * refactor the interface and fixed the styles * add cmake build for embd-input * add cmake build for embd-input * Add MiniGPT-4 example * change the order of the args of llama_eval_internal * fix ci error
Diffstat (limited to 'examples/embd-input/embd-input-test.cpp')
-rw-r--r--examples/embd-input/embd-input-test.cpp35
1 files changed, 35 insertions, 0 deletions
diff --git a/examples/embd-input/embd-input-test.cpp b/examples/embd-input/embd-input-test.cpp
new file mode 100644
index 00000000..e5e040f6
--- /dev/null
+++ b/examples/embd-input/embd-input-test.cpp
@@ -0,0 +1,35 @@
+#include "embd-input.h"
+#include <stdlib.h>
+#include <random>
+#include <string.h>
+
+int main(int argc, char** argv) {
+
+ auto mymodel = create_mymodel(argc, argv);
+ int N = 10;
+ int max_tgt_len = 500;
+ int n_embd = llama_n_embd(mymodel->ctx);
+
+ // add random float embd to test evaluation
+ float * data = new float[N*n_embd];
+ std::default_random_engine e;
+ std::uniform_real_distribution<float> u(0,1);
+ for (int i=0;i<N*n_embd;i++) {
+ data[i] = u(e);
+ }
+
+ eval_string(mymodel, "user: what is the color of the flag of UN?");
+ eval_float(mymodel, data, N);
+ eval_string(mymodel, "assistant:");
+ eval_string(mymodel, mymodel->params.prompt.c_str());
+ const char* tmp;
+ for (int i=0; i<max_tgt_len; i++) {
+ tmp = sampling(mymodel);
+ if (strcmp(tmp, "</s>")==0) break;
+ printf("%s", tmp);
+ fflush(stdout);
+ }
+ printf("\n");
+ free_mymodel(mymodel);
+ return 0;
+}