# api_like_OAI.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. from flask import Flask, jsonify, request, Response
  4. import urllib.parse
  5. import requests
  6. import time
  7. import json
  8. app = Flask(__name__)
  9. slot_id = -1
  10. parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
  11. parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
  12. parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
  13. parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
  14. parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
  15. parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
  16. parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
  17. parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
  18. parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1')
  19. parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081)
  20. args = parser.parse_args()
  21. def is_present(json, key):
  22. try:
  23. buf = json[key]
  24. except KeyError:
  25. return False
  26. if json[key] == None:
  27. return False
  28. return True
  29. #convert chat to prompt
  30. def convert_chat(messages):
  31. system_n = args.system_name
  32. user_n = args.user_name
  33. ai_n = args.ai_name
  34. stop = args.stop
  35. prompt = "" + args.chat_prompt + stop
  36. for line in messages:
  37. if (line["role"] == "system"):
  38. prompt += f"{system_n}{line['content']}{stop}"
  39. if (line["role"] == "user"):
  40. prompt += f"{user_n}{line['content']}{stop}"
  41. if (line["role"] == "assistant"):
  42. prompt += f"{ai_n}{line['content']}{stop}"
  43. prompt += ai_n.rstrip()
  44. return prompt
  45. def make_postData(body, chat=False, stream=False):
  46. postData = {}
  47. if (chat):
  48. postData["prompt"] = convert_chat(body["messages"])
  49. else:
  50. postData["prompt"] = body["prompt"]
  51. if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
  52. if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
  53. if(is_present(body, "top_p")): postData["top_p"] = body["top_p"]
  54. if(is_present(body, "max_tokens")): postData["n_predict"] = body["max_tokens"]
  55. if(is_present(body, "presence_penalty")): postData["presence_penalty"] = body["presence_penalty"]
  56. if(is_present(body, "frequency_penalty")): postData["frequency_penalty"] = body["frequency_penalty"]
  57. if(is_present(body, "repeat_penalty")): postData["repeat_penalty"] = body["repeat_penalty"]
  58. if(is_present(body, "mirostat")): postData["mirostat"] = body["mirostat"]
  59. if(is_present(body, "mirostat_tau")): postData["mirostat_tau"] = body["mirostat_tau"]
  60. if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
  61. if(is_present(body, "seed")): postData["seed"] = body["seed"]
  62. if(is_present(body, "grammar")): postData["grammar"] = body["grammar"]
  63. if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
  64. if (args.stop != ""):
  65. postData["stop"] = [args.stop]
  66. else:
  67. postData["stop"] = []
  68. if(is_present(body, "stop")): postData["stop"] += body["stop"]
  69. postData["n_keep"] = -1
  70. postData["stream"] = stream
  71. postData["cache_prompt"] = True
  72. postData["slot_id"] = slot_id
  73. return postData
  74. def make_resData(data, chat=False, promptToken=[]):
  75. resData = {
  76. "id": "chatcmpl" if (chat) else "cmpl",
  77. "object": "chat.completion" if (chat) else "text_completion",
  78. "created": int(time.time()),
  79. "truncated": data["truncated"],
  80. "model": "LLaMA_CPP",
  81. "usage": {
  82. "prompt_tokens": data["tokens_evaluated"],
  83. "completion_tokens": data["tokens_predicted"],
  84. "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
  85. }
  86. }
  87. if (len(promptToken) != 0):
  88. resData["promptToken"] = promptToken
  89. if (chat):
  90. #only one choice is supported
  91. resData["choices"] = [{
  92. "index": 0,
  93. "message": {
  94. "role": "assistant",
  95. "content": data["content"],
  96. },
  97. "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
  98. }]
  99. else:
  100. #only one choice is supported
  101. resData["choices"] = [{
  102. "text": data["content"],
  103. "index": 0,
  104. "logprobs": None,
  105. "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
  106. }]
  107. return resData
  108. def make_resData_stream(data, chat=False, time_now = 0, start=False):
  109. resData = {
  110. "id": "chatcmpl" if (chat) else "cmpl",
  111. "object": "chat.completion.chunk" if (chat) else "text_completion.chunk",
  112. "created": time_now,
  113. "model": "LLaMA_CPP",
  114. "choices": [
  115. {
  116. "finish_reason": None,
  117. "index": 0
  118. }
  119. ]
  120. }
  121. slot_id = data.get("slot_id")
  122. if (chat):
  123. if (start):
  124. resData["choices"][0]["delta"] = {
  125. "role": "assistant"
  126. }
  127. else:
  128. resData["choices"][0]["delta"] = {
  129. "content": data["content"]
  130. }
  131. if (data["stop"]):
  132. resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
  133. else:
  134. resData["choices"][0]["text"] = data["content"]
  135. if (data["stop"]):
  136. resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
  137. return resData
  138. @app.route('/chat/completions', methods=['POST', 'OPTIONS'])
  139. @app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
  140. def chat_completions():
  141. if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
  142. return Response(status=403)
  143. if request.method == 'OPTIONS':
  144. return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
  145. body = request.get_json()
  146. stream = False
  147. tokenize = False
  148. if(is_present(body, "stream")): stream = body["stream"]
  149. if(is_present(body, "tokenize")): tokenize = body["tokenize"]
  150. postData = make_postData(body, chat=True, stream=stream)
  151. promptToken = []
  152. if (tokenize):
  153. tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
  154. promptToken = tokenData["tokens"]
  155. if (not stream):
  156. data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
  157. print(data.json())
  158. resData = make_resData(data.json(), chat=True, promptToken=promptToken)
  159. return jsonify(resData)
  160. else:
  161. def generate():
  162. data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
  163. time_now = int(time.time())
  164. resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
  165. yield 'data: {}\n\n'.format(json.dumps(resData))
  166. for line in data.iter_lines():
  167. if line:
  168. decoded_line = line.decode('utf-8')
  169. resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
  170. yield 'data: {}\n\n'.format(json.dumps(resData))
  171. return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
  172. @app.route('/completions', methods=['POST', 'OPTIONS'])
  173. @app.route('/v1/completions', methods=['POST', 'OPTIONS'])
  174. def completion():
  175. if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
  176. return Response(status=403)
  177. if request.method == 'OPTIONS':
  178. return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
  179. body = request.get_json()
  180. stream = False
  181. tokenize = False
  182. if(is_present(body, "stream")): stream = body["stream"]
  183. if(is_present(body, "tokenize")): tokenize = body["tokenize"]
  184. postData = make_postData(body, chat=False, stream=stream)
  185. promptToken = []
  186. if (tokenize):
  187. tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
  188. promptToken = tokenData["tokens"]
  189. if (not stream):
  190. data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
  191. print(data.json())
  192. resData = make_resData(data.json(), chat=False, promptToken=promptToken)
  193. return jsonify(resData)
  194. else:
  195. def generate():
  196. data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
  197. time_now = int(time.time())
  198. for line in data.iter_lines():
  199. if line:
  200. decoded_line = line.decode('utf-8')
  201. resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
  202. yield 'data: {}\n\n'.format(json.dumps(resData))
  203. return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
  204. if __name__ == '__main__':
  205. app.run(args.host, port=args.port)