Merge pull request #10 from vtuber-plan/development
Fix bug
FrostMiKu authored Jun 21, 2023
2 parents 7428edf + b4ee024 commit 7ff7fbe
Showing 6 changed files with 74 additions and 66 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -125,7 +125,7 @@ python -m langport.service.gateway.openai_api
Run text generation with ggml worker:

```bash
python -m langport.service.server.generation_worker --port 21001 --model-path <your model path> --gpu-layers <num layer to gpu (resize this for your VRAM)>
python -m langport.service.server.ggml_generation_worker --port 21001 --model-path <your model path> --gpu-layers <num layer to gpu (resize this for your VRAM)>
```

## License
71 changes: 47 additions & 24 deletions langport/model/executor/generation/ggml.py
@@ -1,3 +1,4 @@
import threading
from typing import List, Optional
from langport.model.executor.ggml import GgmlExecutor, GgmlTokenizer
from ctransformers import LLM
@@ -30,29 +31,30 @@ def stream_generation(
top_k = 40 if task.top_k <= 1 else task.top_k
repetition_penalty = 1.17647 if task.repetition_penalty == 0.0 else task.repetition_penalty

for j, token in enumerate(model.generate(tokens, top_k=top_k, top_p=task.top_p,
temperature=task.temperature, repetition_penalty=repetition_penalty)):
finish_reason = "stop"
n_tokens = 0
for token in model.generate(
tokens, top_k=top_k, top_p=task.top_p, batch_size=512,
temperature=task.temperature, repetition_penalty=repetition_penalty):
n_tokens += 1
output_ids.append(token)
if tokenizer.is_eos_token(token) or prompt_length + j == task.max_tokens - 1:
if n_tokens == task.max_tokens:
output = tokenizer.decode(output_ids)
if tokenizer.is_eos_token(token):
finish_reason = "stop"
else:
finish_reason = "length"
finish_reason = "length"
yield GenerationWorkerResult(
task_id=task.task_id,
type="finish",
text=output,
usage=UsageInfo(
prompt_tokens=prompt_length,
total_tokens=prompt_length + j,
completion_tokens=j,
total_tokens=prompt_length + n_tokens,
completion_tokens=n_tokens,
),
finish_reason=finish_reason,
)
break

if j%stream_interval!=0:
if n_tokens % stream_interval != 0:
continue
output = tokenizer.decode(output_ids)

@@ -63,12 +65,27 @@ def stream_generation(
text=output,
usage=UsageInfo(
prompt_tokens=prompt_length,
total_tokens=prompt_length + j,
completion_tokens=j,
total_tokens=prompt_length + n_tokens,
completion_tokens=n_tokens,
),
finish_reason=None,
)

# token == eos is checked in model.generate
if finish_reason == "stop":
output = tokenizer.decode(output_ids)
yield GenerationWorkerResult(
task_id=task.task_id,
type="finish",
text=output,
usage=UsageInfo(
prompt_tokens=prompt_length,
total_tokens=prompt_length + n_tokens,
completion_tokens=n_tokens,
),
finish_reason="stop",
)
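
Note on the control flow above: intermediate chunks are yielded with finish_reason=None every stream_interval tokens, a "length" finish is yielded from inside the loop once n_tokens reaches task.max_tokens, and a "stop" finish is yielded after the loop when model.generate ends on an EOS token. A minimal consumer sketch, assuming only the GenerationWorkerResult fields visible in this diff (text and finish_reason):

```python
# Hypothetical consumer (not part of this commit): collects streamed chunks
# from stream_generation and reports how generation ended.
def consume(chunks):
    text, reason = "", None
    for chunk in chunks:
        text = chunk.text              # each chunk carries the full decoded text so far
        reason = chunk.finish_reason   # None for intermediate chunks, "stop"/"length" at the end
    return text, reason
```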


class GgmlGenerationExecutor(GgmlExecutor):
def __init__(
@@ -93,6 +110,7 @@ def __init__(
)
self.n_ctx = context_length
self.adapter, self.model, self.tokenizer = self.load_model(model_path, from_pretrained_kwargs={})
self.lock = threading.Lock()

@property
def context_length(self) -> int:
@@ -109,18 +127,23 @@ def inference(self, worker: "GenerationModelWorker"):
batch_size = len(tasks)
if batch_size == 0:
return

self.lock.acquire()

# batch inference
for chunk in stream_generation(
self.model,
self.tokenizer,
worker.stream_interval,
tasks,
):
worker.push_task_result(chunk.task_id, chunk)

for task in tasks:
worker.push_task_result(
task.task_id, BaseWorkerResult(task_id=task.task_id, type="done")
)
try:
for chunk in stream_generation(
self.model,
self.tokenizer,
worker.stream_interval,
tasks,
):
worker.push_task_result(chunk.task_id, chunk)

for task in tasks:
worker.push_task_result(
task.task_id, BaseWorkerResult(task_id=task.task_id, type="done")
)
finally:
self.lock.release()
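
The lock added in __init__ serializes batches through the single ctransformers model instance, and the try/finally ensures it is released even if stream_generation raises, so one failed batch cannot block every later call. A minimal, self-contained sketch of the same pattern (names here are illustrative, not from the repo); `with lock:` is equivalent to the explicit acquire/try/finally used above:

```python
import threading

lock = threading.Lock()

def run_serialized(generate_batch):
    # Acquire the lock for the whole batch and release it even if generation
    # raises -- the same guarantee the try/finally above provides.
    with lock:
        return list(generate_batch())

# Usage sketch: concurrent callers take turns; only one batch runs at a time.
print(run_serialized(lambda: ["chunk-1", "chunk-2"]))
```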

13 changes: 9 additions & 4 deletions langport/model/executor/generation/huggingface.py
@@ -151,7 +151,7 @@ def update_new_token(self, batch_token: List[int]):
self.set_stop(i)
if self.tasks[i].stop_token_ids is not None and token in self.tasks[i].stop_token_ids:
self.set_stop(i)
if self.get_prompt_length(i) + self.get_generated_length(i) >= self.max_tokens[i]:
if self.get_generated_length(i) == self.max_tokens[i]:
self.set_stop(i)

def set_stop(self, idx:int):
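
This changes what max_tokens bounds: the old check counted prompt plus generated tokens, while the new one counts generated tokens only, so max_tokens now limits the completion length regardless of prompt size. A small worked example with illustrative numbers (not from the repo):

```python
# Illustrative numbers only: a 10-token prompt with max_tokens = 16.
prompt_len, max_tokens = 10, 16

old_budget = max_tokens - prompt_len  # old check: prompt + generated >= max_tokens -> 6 new tokens
new_budget = max_tokens               # new check: generated == max_tokens -> 16 new tokens

print(old_budget, new_budget)  # 6 16
```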
@@ -260,7 +260,7 @@ def generate(self, inputs: BatchingTask,

if all(inputs.stop):
break


# stop all
for i in range(inputs.batch_size):
if not inputs.is_stop(i):
inputs.set_stop(i)
if streamer:
streamer.end()

@@ -292,7 +296,7 @@ def put(self, value):
prompt_len = self.task_batch.get_prompt_length(i)

if self.task_batch.is_stop(i):
if prompt_len + generated_len >= self.task_batch.max_tokens[i]:
if generated_len == self.task_batch.max_tokens[i]:
finish_reason = "length"
else:
finish_reason = "stop"
@@ -329,7 +333,8 @@ def put(self, value):
)

def end(self):
pass
# check all done
self.put(None)


def stop_by_stopwords(
50 changes: 15 additions & 35 deletions langport/service/gateway/openai_api.py
@@ -44,36 +44,16 @@ async def dispatch(self, request, call_next):
)
return await call_next(request)


class RedirectModelMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, redirect_rules:list, dispatch: Optional[DispatchFunction] = None) -> None:
super().__init__(app, dispatch)
self.redirect_rules = redirect_rules
self.receive_ = None

async def dispatch(self, request, call_next):
if "content-type" not in request.headers or request.headers["content-type"] != "application/json":
return await call_next(request)

try:
await self.set_body(request)
data = await request.json()
if "model" in data:
for rule in self.redirect_rules:
from_model_name, to_model_name = rule.split(":")
if data["model"] == from_model_name:
data["model"] = to_model_name
self.receive_['body'] = json.dumps(data).encode("utf-8")
break
except Exception as e:
logger.error(f"RedirectModelMiddleware: {e}")
return await call_next(request)

async def set_body(self, request):
self.receive_ = await request._receive()
async def receive():
return self.receive_
request._receive = receive
redirect_rules = None
def redirect_model_name(model:str):
if redirect_rules is not None:
for rule in redirect_rules:
from_model_name, to_model_name = rule.split(":")
if model == from_model_name:
logger.debug(f"Redirect model {from_model_name} to {to_model_name}")
model = to_model_name
break
return model
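
The middleware that rewrote the request body is replaced by this plain function, applied to request.model in each endpoint below; rules keep the same "from:to" string format. A usage sketch with made-up model names (the function body is repeated only to keep the snippet runnable on its own):

```python
# Made-up rule and model names for illustration.
redirect_rules = ["gpt-3.5-turbo:vicuna-13b-v1.3"]

def redirect_model_name(model: str) -> str:
    if redirect_rules is not None:
        for rule in redirect_rules:
            from_model_name, to_model_name = rule.split(":")
            if model == from_model_name:
                model = to_model_name
                break
    return model

print(redirect_model_name("gpt-3.5-turbo"))   # -> vicuna-13b-v1.3
print(redirect_model_name("unmapped-model"))  # unchanged
```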


@app.exception_handler(RequestValidationError)
@@ -88,15 +68,18 @@ async def models():

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
request.model = redirect_model_name(request.model)
return await api_chat_completions(app.app_settings, request)

@app.post("/v1/completions")
async def completions(request: CompletionRequest):
request.model = redirect_model_name(request.model)
return await api_completions(app.app_settings, request)


@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingsRequest):
request.model = redirect_model_name(request.model)
return await api_embeddings(app.app_settings, request)


@@ -133,10 +116,7 @@ async def embeddings(request: EmbeddingsRequest):
allow_headers=args.allowed_headers,
)
if args.redirect is not None:
app.add_middleware(
RedirectModelMiddleware,
redirect_rules=args.redirect,
)
redirect_rules = args.redirect
if args.sk is not None:
app.add_middleware(
BaseAuthorizationMiddleware,
@@ -153,5 +133,5 @@ async def embeddings(request: EmbeddingsRequest):
host=args.host,
port=args.port,
log_level="info",
reload=True,
reload=False,
)
2 changes: 1 addition & 1 deletion langport/version.py
@@ -1 +1 @@
LANGPORT_VERSION = "0.2.0"
LANGPORT_VERSION = "0.2.1"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "langport"
version = "0.2.0"
version = "0.2.1"
description = "A large language model serving platform."
readme = "README.md"
requires-python = ">=3.8"