feat(model): Passing stop parameter to proxyllm (#2077)
fangyinc authored Oct 18, 2024
1 parent cf192a5 · commit 53ba625
Showing 13 changed files with 31 additions and 4 deletions.
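
For orientation, this is roughly how the widened stop parameter is expected to reach the proxy models from a client of the OpenAI-compatible apiserver. A minimal sketch, assuming a locally running apiserver; the base URL, API key, and deployed model name are illustrative assumptions, not values taken from this commit.

```python
# Sketch: sending a stop list through the OpenAI-compatible apiserver (values are assumptions).
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8100/api/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="chatgpt_proxyllm",  # hypothetical deployed proxy model name
    messages=[{"role": "user", "content": "List three colors."}],
    stop=["\n\n", "4."],  # a single string or a list of up to 4 strings (see check_requests below)
    max_tokens=64,
)
print(completion.choices[0].message.content)
```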
2 changes: 1 addition & 1 deletion dbgpt/core/interface/llm.py

@@ -201,7 +201,7 @@ class ModelRequest:
     max_new_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""

-    stop: Optional[str] = None
+    stop: Optional[Union[str, List[str]]] = None
     """The stop condition of the model inference."""
     stop_token_ids: Optional[List[int]] = None
     """The stop token ids of the model inference."""
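
A quick illustration of the widened field: a minimal sketch assuming ModelRequest can be constructed directly from a model name and a message list, with the other fields left at their defaults; the model name and message text are placeholders.

```python
# Sketch: ModelRequest.stop now accepts either a single string or a list of strings.
# The model name and message content are placeholders.
from dbgpt.core.interface.llm import ModelRequest
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

messages = [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Tell me a short joke.")]

single_stop = ModelRequest(model="proxyllm", messages=messages, stop="\n\n")
multi_stop = ModelRequest(model="proxyllm", messages=messages, stop=["Observation:", "\n\n"])
```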
13 changes: 13 additions & 0 deletions dbgpt/model/cluster/apiserver/api.py

@@ -60,6 +60,7 @@ def __init__(self, code: int, message: str):
 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
     embedding_bach_size: int = 4
+    ignore_stop_exceeds_error: bool = False


 api_settings = APISettings()

@@ -146,6 +147,15 @@ def check_requests(request) -> Optional[JSONResponse]:
             ErrorCode.PARAM_OUT_OF_RANGE,
             f"{request.stop} is not valid under any of the given schemas - 'stop'",
         )
+    if request.stop and isinstance(request.stop, list) and len(request.stop) > 4:
+        # https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop
+        if not api_settings.ignore_stop_exceeds_error:
+            return create_error_response(
+                ErrorCode.PARAM_OUT_OF_RANGE,
+                f"Invalid 'stop': array too long. Expected an array with maximum length 4, but got an array with length {len(request.stop)} instead.",
+            )
+        else:
+            request.stop = request.stop[:4]

     return None

@@ -581,6 +591,7 @@ def initialize_apiserver(
     port: int = None,
     api_keys: List[str] = None,
     embedding_batch_size: Optional[int] = None,
+    ignore_stop_exceeds_error: bool = False,
 ):
     import os

@@ -614,6 +625,7 @@

     if embedding_batch_size:
         api_settings.embedding_bach_size = embedding_batch_size
+    api_settings.ignore_stop_exceeds_error = ignore_stop_exceeds_error

     app.include_router(router, prefix="/api", tags=["APIServer"])

@@ -664,6 +676,7 @@ def run_apiserver():
         port=apiserver_params.port,
         api_keys=api_keys,
         embedding_batch_size=apiserver_params.embedding_batch_size,
+        ignore_stop_exceeds_error=apiserver_params.ignore_stop_exceeds_error,
     )
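
To make the new validation concrete: more than four stop strings is rejected by default, mirroring the OpenAI limit linked in the comment, unless ignore_stop_exceeds_error is set, in which case the list is silently truncated to the first four. The snippet below is a self-contained restatement of that rule for illustration, not the apiserver code itself.

```python
# Standalone restatement of the stop-length rule added to check_requests (illustration only).
from typing import List, Optional, Union

MAX_STOP_SEQUENCES = 4  # OpenAI allows at most 4 stop sequences per request


def normalize_stop(
    stop: Optional[Union[str, List[str]]], ignore_stop_exceeds_error: bool = False
) -> Optional[Union[str, List[str]]]:
    """Return the stop value to use, or raise if the list is too long."""
    if stop and isinstance(stop, list) and len(stop) > MAX_STOP_SEQUENCES:
        if not ignore_stop_exceeds_error:
            raise ValueError(
                f"Invalid 'stop': array too long. Expected an array with maximum length "
                f"{MAX_STOP_SEQUENCES}, but got an array with length {len(stop)} instead."
            )
        return stop[:MAX_STOP_SEQUENCES]
    return stop


print(normalize_stop(["a", "b", "c", "d", "e"], ignore_stop_exceeds_error=True))
# -> ['a', 'b', 'c', 'd']; with the default setting the same call raises an error
```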
4 changes: 2 additions & 2 deletions dbgpt/model/cluster/base.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt._private.pydantic import BaseModel
 from dbgpt.core.interface.message import ModelMessage

@@ -15,7 +15,7 @@ class PromptRequest(BaseModel):
     prompt: str = None
     temperature: float = None
     max_new_tokens: int = None
-    stop: str = None
+    stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: List[int] = []
     context_len: int = None
     echo: bool = True
3 changes: 3 additions & 0 deletions dbgpt/model/parameter.py

@@ -167,6 +167,9 @@ class ModelAPIServerParameters(BaseServerParameters):
     embedding_batch_size: Optional[int] = field(
         default=None, metadata={"help": "Embedding batch size"}
     )
+    ignore_stop_exceeds_error: Optional[bool] = field(
+        default=False, metadata={"help": "Ignore exceeds stop words error"}
+    )

     log_file: Optional[str] = field(
         default="dbgpt_model_apiserver.log",
3 changes: 3 additions & 0 deletions dbgpt/model/proxy/llms/chatgpt.py

@@ -39,6 +39,7 @@ async def chatgpt_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r

@@ -188,6 +189,8 @@ def _build_request(
         payload["temperature"] = request.temperature
         if request.max_new_tokens:
             payload["max_tokens"] = request.max_new_tokens
+        if request.stop:
+            payload["stop"] = request.stop
         return payload

     async def generate(
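
The _build_request hunk means a non-empty stop value now travels into the OpenAI-style request body alongside max_tokens. A rough sketch of the resulting payload shape; only the max_tokens and stop handling are taken from the hunk above, the other fields are illustrative.

```python
# Sketch of the payload shape _build_request produces once request.stop is set.
# Only the "max_tokens" and "stop" handling mirror the hunk above; other values are illustrative.
request_stop = ["Observation:", "\n\n"]

payload = {
    "model": "gpt-4o-mini",  # hypothetical upstream model name
    "messages": [{"role": "user", "content": "Think step by step."}],
    "temperature": 0.7,
    "max_tokens": 256,
}
if request_stop:
    payload["stop"] = request_stop  # mirrors: if request.stop: payload["stop"] = request.stop

print(payload)
```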
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/deepseek.py

@@ -27,6 +27,7 @@ async def deepseek_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/gemini.py

@@ -46,6 +46,7 @@ def gemini_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/moonshot.py

@@ -26,6 +26,7 @@ async def moonshot_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/spark.py

@@ -47,6 +47,7 @@ def spark_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
2 changes: 2 additions & 0 deletions dbgpt/model/proxy/llms/tongyi.py

@@ -21,6 +21,7 @@ def tongyi_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r

@@ -96,6 +97,7 @@ def sync_generate_stream(
             top_p=0.8,
             stream=True,
             result_format="message",
+            stop=request.stop,
         )
         for r in res:
             if r:
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/yi.py

@@ -26,6 +26,7 @@ async def yi_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/zhipu.py

@@ -28,6 +28,7 @@ def zhipu_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
2 changes: 1 addition & 1 deletion dbgpt/rag/knowledge/docx.py

@@ -64,7 +64,7 @@ def _load(self) -> List[Document]:
             documents = self._loader.load()
         else:
             docs = []
-            _SerializedRelationships.load_from_xml = load_from_xml_v2 # type: ignore
+            _SerializedRelationships.load_from_xml = load_from_xml_v2  # type: ignore
             doc = docx.Document(self._path)
             content = []

