From 23b9d3af4971a983b1d0a141cc6492e7be25daf5 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Mon, 2 Oct 2023 12:45:14 +0800 Subject: [PATCH] force oai endpoints to return json --- koboldcpp.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 900be3c104447..406d0f16f1819 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -526,6 +526,7 @@ def do_GET(self): global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens self.path = self.path.rstrip('/') response_body = None + force_json = False if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without / if args.stream and not "streaming=1" in self.path: @@ -585,6 +586,7 @@ def do_GET(self): elif self.path.endswith('/v1/models') or self.path.endswith('/models'): response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode()) + force_json = True elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')): response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode()) @@ -598,7 +600,7 @@ def do_GET(self): else: self.send_response(200) self.send_header('Content-Length', str(len(response_body))) - self.end_headers() + self.end_headers(force_json=force_json) self.wfile.write(response_body) return @@ -607,6 +609,7 @@ def do_POST(self): content_length = int(self.headers['Content-Length']) body = self.rfile.read(content_length) self.path = self.path.rstrip('/') + force_json = False if self.path.endswith(('/api/extra/tokencount')): try: @@ -686,6 +689,7 @@ def do_POST(self): if self.path.endswith('/v1/completions') or self.path.endswith('/completions'): api_format = 3 + force_json = True if api_format>0: genparams = None @@ -707,7 +711,7 @@ def do_POST(self): # Headers are already sent when streaming if not kai_sse_stream_flag: self.send_response(200) - self.end_headers() + self.end_headers(force_json=force_json) self.wfile.write(json.dumps(gen).encode()) except: print("Generate: The response could not be sent, maybe connection was terminated?") @@ -728,11 +732,11 @@ def do_HEAD(self): self.send_response(200) self.end_headers() - def end_headers(self): + def end_headers(self, force_json=False): self.send_header('Access-Control-Allow-Origin', '*') self.send_header('Access-Control-Allow-Methods', '*') self.send_header('Access-Control-Allow-Headers', '*') - if "/api" in self.path: + if "/api" in self.path or force_json: if self.path.endswith("/stream"): self.send_header('Content-type', 'text/event-stream') self.send_header('Content-type', 'application/json')