diff options
author | Pavol Rusnak <pavol@rusnak.io> | 2023-03-29 21:31:24 +0200 |
---|---|---|
committer | Pavol Rusnak <pavol@rusnak.io> | 2023-03-31 10:32:01 +0200 |
commit | cbef542879962fdc491656cd0c8cadd65a5f1356 (patch) | |
tree | ba31f66c0613411466b31c822fb5bac2b24c910a /convert-ggml-to-pth.py | |
parent | 9733104be5389ebb1ff05095eca2a70280cd875a (diff) |
py : cleanup the code
- use f-strings where possible
- drop first param of encode/decode functions since "utf-8" is the default
Diffstat (limited to 'convert-ggml-to-pth.py')
-rw-r--r-- | convert-ggml-to-pth.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/convert-ggml-to-pth.py b/convert-ggml-to-pth.py index 8ab17410..7ddfe3a1 100644 --- a/convert-ggml-to-pth.py +++ b/convert-ggml-to-pth.py @@ -27,9 +27,9 @@ def read_tokens(fin, vocab_size): text_len = struct.unpack("i", fin.read(4))[0] text_bytes = fin.read(text_len) try: - text = text_bytes.decode("utf-8") + text = text_bytes.decode() except UnicodeDecodeError: - text = text_bytes.decode("utf-8", "replace") + text = text_bytes.decode(errors="replace") score = struct.unpack("f", fin.read(4))[0] tokens.append((text, score)) return tokens @@ -82,7 +82,7 @@ def read_variables(fin): shape = tuple(struct.unpack("i" * n_dims, fin.read(4 * n_dims))) shape = shape[::-1] - name = fin.read(name_length).decode("utf-8") + name = fin.read(name_length).decode() # ensure tensor data is aligned tensor_data_offset = fin.tell() @@ -199,7 +199,7 @@ def chat(model, hparams, llama_dir): device = torch.device("cpu") llama = llama.to(device) - ctx = """You are AI. + ctx = """You are AI. This is a dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, respectful, direct, concise, should try to protect User's privacy, and knows its own limits. Also, AI must answer User and AI cannot stop the conversation by itself. User: Hello, AI. AI: Hello! How can I assist you today? @@ -207,11 +207,11 @@ AI: Hello! How can I assist you today? print(ctx.rstrip("\n")) while True: print("-" * 60) - prompt = input(f"User: ") + prompt = input("User: ") if ctx != "": - ctx = ctx + "User: " + prompt + "\n" + ctx = f"{ctx}User: {prompt}\n" else: - ctx = prompt + "\nAI:" + ctx = f"{prompt}\nAI:" ctx = (ctx[-1920:]) if len(ctx) >= 2048 else ctx @@ -236,7 +236,7 @@ AI: Hello! How can I assist you today? ) s = generation_output.sequences[0] decoded = tokenizer.decode(s) - ctx = decoded + "\n" + ctx = f"{decoded}\n" def main(): |