summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorrhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com>2023-11-30 20:50:40 +0000
committerGitHub <noreply@github.com>2023-11-30 22:50:40 +0200
commite2bd725f4b39bc5c6234858d158e01248f5ab5bd (patch)
treefa8fe8fe3d867f3d36d719d8a4dbeea5183f7412
parent1f5cd83275fabb43f2ae92c30033b384a3eb37b4 (diff)
py : fix oai proxy (#3972)
* fix oai proxy fix generation not stoped while bot stop talking in chat mode fix possible `slot_id` not exist response for cors (and pre flight) * oai proxy: workaround for some client (such as Chatbox) * use stop as separator to replace hardcoded `\n`
-rwxr-xr-xexamples/server/api_like_OAI.py46
1 files changed, 25 insertions, 21 deletions
diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py
index 313e1a96..830c056d 100755
--- a/examples/server/api_like_OAI.py
+++ b/examples/server/api_like_OAI.py
@@ -11,10 +11,10 @@ app = Flask(__name__)
slot_id = -1
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
-parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
-parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
-parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
-parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
+parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
+parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
+parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
+parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
@@ -34,19 +34,19 @@ def is_present(json, key):
#convert chat to prompt
def convert_chat(messages):
- prompt = "" + args.chat_prompt.replace("\\n", "\n")
- system_n = args.system_name.replace("\\n", "\n")
- user_n = args.user_name.replace("\\n", "\n")
- ai_n = args.ai_name.replace("\\n", "\n")
- stop = args.stop.replace("\\n", "\n")
+ system_n = args.system_name
+ user_n = args.user_name
+ ai_n = args.ai_name
+ stop = args.stop
+ prompt = "" + args.chat_prompt + stop
for line in messages:
if (line["role"] == "system"):
- prompt += f"{system_n}{line['content']}"
+ prompt += f"{system_n}{line['content']}{stop}"
if (line["role"] == "user"):
- prompt += f"{user_n}{line['content']}"
+ prompt += f"{user_n}{line['content']}{stop}"
if (line["role"] == "assistant"):
prompt += f"{ai_n}{line['content']}{stop}"
prompt += ai_n.rstrip()
@@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
}
]
}
- slot_id = data["slot_id"]
+ slot_id = data.get("slot_id")
if (chat):
if (start):
resData["choices"][0]["delta"] = {
@@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
return resData
-@app.route('/chat/completions', methods=['POST'])
-@app.route('/v1/chat/completions', methods=['POST'])
+@app.route('/chat/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
def chat_completions():
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
return Response(status=403)
+ if request.method == 'OPTIONS':
+ return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
body = request.get_json()
stream = False
tokenize = False
@@ -177,20 +179,22 @@ def chat_completions():
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
time_now = int(time.time())
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
- yield 'data: {}\n'.format(json.dumps(resData))
+ yield 'data: {}\n\n'.format(json.dumps(resData))
for line in data.iter_lines():
if line:
decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
- yield 'data: {}\n'.format(json.dumps(resData))
- return Response(generate(), mimetype='text/event-stream')
+ yield 'data: {}\n\n'.format(json.dumps(resData))
+ return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
-@app.route('/completions', methods=['POST'])
-@app.route('/v1/completions', methods=['POST'])
+@app.route('/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/completions', methods=['POST', 'OPTIONS'])
def completion():
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
return Response(status=403)
+ if request.method == 'OPTIONS':
+ return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
body = request.get_json()
stream = False
tokenize = False
@@ -216,8 +220,8 @@ def completion():
if line:
decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
- yield 'data: {}\n'.format(json.dumps(resData))
- return Response(generate(), mimetype='text/event-stream')
+ yield 'data: {}\n\n'.format(json.dumps(resData))
+ return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
if __name__ == '__main__':
app.run(args.host, port=args.port)