diff options
author | ningshanwutuobang <ningshanwutuobang@gmail.com> | 2023-06-28 23:53:37 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-28 18:53:37 +0300 |
commit | cfa0750bc9dbc2d957a91b8ed09ab0035d8f3d4e (patch) | |
tree | c8d6d6e6548d4f03899704f64bce6939e471e4e6 /examples/embd-input/embd-input-test.cpp | |
parent | 9d23589d638dc74577d5ff880e6d4248b795f12e (diff) |
llama : support input embeddings directly (#1910)
* add interface for float input
* fixed inpL shape and type
* add examples of input floats
* add test example for embd input
* fixed sampling
* add free for context
* fixed add end condition for generating
* add examples for llava.py
* add READMD for llava.py
* add READMD for llava.py
* add example of PandaGPT
* refactor the interface and fixed the styles
* add cmake build for embd-input
* add cmake build for embd-input
* Add MiniGPT-4 example
* change the order of the args of llama_eval_internal
* fix ci error
Diffstat (limited to 'examples/embd-input/embd-input-test.cpp')
-rw-r--r-- | examples/embd-input/embd-input-test.cpp | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/examples/embd-input/embd-input-test.cpp b/examples/embd-input/embd-input-test.cpp new file mode 100644 index 00000000..e5e040f6 --- /dev/null +++ b/examples/embd-input/embd-input-test.cpp @@ -0,0 +1,35 @@ +#include "embd-input.h" +#include <stdlib.h> +#include <random> +#include <string.h> + +int main(int argc, char** argv) { + + auto mymodel = create_mymodel(argc, argv); + int N = 10; + int max_tgt_len = 500; + int n_embd = llama_n_embd(mymodel->ctx); + + // add random float embd to test evaluation + float * data = new float[N*n_embd]; + std::default_random_engine e; + std::uniform_real_distribution<float> u(0,1); + for (int i=0;i<N*n_embd;i++) { + data[i] = u(e); + } + + eval_string(mymodel, "user: what is the color of the flag of UN?"); + eval_float(mymodel, data, N); + eval_string(mymodel, "assistant:"); + eval_string(mymodel, mymodel->params.prompt.c_str()); + const char* tmp; + for (int i=0; i<max_tgt_len; i++) { + tmp = sampling(mymodel); + if (strcmp(tmp, "</s>")==0) break; + printf("%s", tmp); + fflush(stdout); + } + printf("\n"); + free_mymodel(mymodel); + return 0; +} |