author     ningshanwutuobang <ningshanwutuobang@gmail.com>  2023-06-28 23:53:37 +0800
committer  GitHub <noreply@github.com>                      2023-06-28 18:53:37 +0300
commit     cfa0750bc9dbc2d957a91b8ed09ab0035d8f3d4e (patch)
tree       c8d6d6e6548d4f03899704f64bce6939e471e4e6 /examples/embd-input/embd_input.py
parent     9d23589d638dc74577d5ff880e6d4248b795f12e (diff)
llama : support input embeddings directly (#1910)

* add interface for float input
* fixed inpL shape and type
* add examples of input floats
* add test example for embd input
* fixed sampling
* add free for context
* fixed end condition for generating
* add examples for llava.py
* add README for llava.py
* add example of PandaGPT
* refactor the interface and fixed the styles
* add cmake build for embd-input
* add MiniGPT-4 example
* change the order of the args of llama_eval_internal
* fix ci error
Diffstat (limited to 'examples/embd-input/embd_input.py')
-rw-r--r--  examples/embd-input/embd_input.py  80
1 file changed, 80 insertions, 0 deletions
diff --git a/examples/embd-input/embd_input.py b/examples/embd-input/embd_input.py
new file mode 100644
index 00000000..be289661
--- /dev/null
+++ b/examples/embd-input/embd_input.py
@@ -0,0 +1,80 @@
+import ctypes
+from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int
+import numpy as np
+import os
+
+libc = cdll.LoadLibrary("./libembdinput.so")
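+# declare argument/return types for the C helpers used below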
+libc.create_mymodel.restype = c_void_p
+libc.free_mymodel.argtypes = [c_void_p]
+libc.eval_float.argtypes = [c_void_p, POINTER(c_float), c_int]
+libc.eval_string.argtypes = [c_void_p, c_char_p]
+libc.eval_id.argtypes = [c_void_p, c_int]
+libc.sampling.argtypes = [c_void_p]
+libc.sampling.restype = c_char_p
+
+
+class MyModel:
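+    # Thin ctypes wrapper around the embd-input helpers: it owns the C-side
+    # model handle and interleaves text, token-id and raw-embedding input.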
+    def __init__(self, args):
+        argc = len(args)
+        c_str = [c_char_p(i.encode()) for i in args]
+        args_c = (c_char_p * argc)(*c_str)
+        self.model = c_void_p(libc.create_mymodel(argc, args_c))
+        self.max_tgt_len = 512  # maximum number of pieces sampled per generation
+        self.print_string_eval = True  # echo text passed to eval_string
+
+    def __del__(self):
+        libc.free_mymodel(self.model)
+
+    def eval_float(self, x):
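+        # x: float array of shape (n_embd, n_tokens); it is converted to
+        # float32 and the C side consumes x.shape[1] embedding columns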
+        libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1])
+
+    def eval_string(self, x):
+        libc.eval_string(self.model, x.encode())
+        if self.print_string_eval:
+            print(x)
+
+    def eval_token(self, x):
+        libc.eval_id(self.model, x)
+
+    def sampling(self):
+        # returns the next sampled piece as bytes (restype is c_char_p)
+        return libc.sampling(self.model)
+
+    def stream_generate(self, end="</s>"):
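+        # yield sampled byte pieces until the end marker appears or max_tgt_len is reached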
+        ret = b""
+        end = end.encode()
+        for _ in range(self.max_tgt_len):
+            tmp = self.sampling()
+            ret += tmp
+            yield tmp
+            if ret.endswith(end):
+                break
+
+    def generate_with_print(self, end="</s>"):
+        ret = b""
+        for i in self.stream_generate(end=end):
+            ret += i
+            print(i.decode(errors="replace"), end="", flush=True)
+        print()
+        return ret.decode(errors="replace")
+
+    def generate(self, end="</s>"):
+        text = b"".join(self.stream_generate(end=end))
+        return text.decode(errors="replace")
+
+
+if __name__ == "__main__":
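+    # argv-style args are forwarded to create_mymodel (the same flags llama.cpp's main accepts)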
+    model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
+    model.eval_string("""user: what is the color of the flag of UN?""")
+    x = np.random.random((5120, 10))  # (n_embd, n_tokens); 5120 matches the 13B model width
+    model.eval_float(x)
+    model.eval_string("""assistant:""")
+    for i in model.stream_generate():
+        print(i.decode(errors="replace"), end="", flush=True)
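
The random array in the example above only stands in for real encoder output. Below is a rough, self-contained sketch of how the same call pattern could carry actual multimodal features, in the spirit of the llava.py and PandaGPT examples this commit mentions; encode_image is a stand-in stub, not part of this commit, and the file is assumed importable as embd_input:

import numpy as np
from embd_input import MyModel

def encode_image(path):
    # stub for a real encoder (llava.py feeds CLIP features through a projection);
    # a real implementation must return float32 features of shape (n_embd, n_tokens)
    return np.random.random((5120, 10)).astype(np.float32)

model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
model.eval_string("user: what is shown in this image?\n")
model.eval_float(encode_image("image.png"))  # embeddings enter the context in place of tokens
model.eval_string("assistant:")
print(model.generate())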