Skip to content

Commit

Permalink
update changes for all endpoint types
Browse files (browse the repository at this point in the history)
  • Loading branch information
SecretiveShell committed Sep 21, 2024
1 parent 035269c commit cef9a03
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 25 deletions.
10 changes: 2 additions & 8 deletions endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import pathlib
from asyncio import CancelledError
- from copy import deepcopy
from typing import List, Optional
import json

Expand Down Expand Up @@ -291,13 +290,8 @@ async def stream_generate_chat_completion(
try:
logger.info(f"Received chat completion streaming request {request.state.id}")

- gen_params = data.to_gen_params()
-
for n in range(0, data.n):
- if n > 0:
- task_gen_params = deepcopy(gen_params)
- else:
- task_gen_params = gen_params
+ task_gen_params = data.model_copy(deep=True)

gen_task = asyncio.create_task(
_stream_collector(
Expand All @@ -306,7 +300,7 @@ async def stream_generate_chat_completion(
prompt,
request.state.id,
abort_event,
- **task_gen_params,
+ **task_gen_params.model_dump(),
)
)

Expand Down
22 changes: 5 additions & 17 deletions endpoints/OAI/utils/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import asyncio
import pathlib
from asyncio import CancelledError
- from copy import deepcopy
from fastapi import HTTPException, Request
from typing import List, Union

Expand Down Expand Up @@ -166,13 +165,8 @@ async def stream_generate_completion(
try:
logger.info(f"Received streaming completion request {request.state.id}")

- gen_params = data.to_gen_params()
-
for n in range(0, data.n):
- if n > 0:
- task_gen_params = deepcopy(gen_params)
- else:
- task_gen_params = gen_params
+ task_gen_params = data.model_copy(deep=True)

gen_task = asyncio.create_task(
_stream_collector(
Expand All @@ -181,7 +175,7 @@ async def stream_generate_completion(
data.prompt,
request.state.id,
abort_event,
- **task_gen_params,
+ **task_gen_params.model_dump(),
)
)

Expand Down Expand Up @@ -229,23 +223,17 @@ async def generate_completion(
"""Non-streaming generate for completions"""

gen_tasks: List[asyncio.Task] = []
- gen_params = data.to_gen_params()

try:
logger.info(f"Recieved completion request {request.state.id}")

- for n in range(0, data.n):
- # Deepcopy gen params above the first index
- # to ensure nested structures aren't shared
- if n > 0:
- task_gen_params = deepcopy(gen_params)
- else:
- task_gen_params = gen_params
+ for _ in range(0, data.n):
+ task_gen_params = data.model_copy(deep=True)

gen_tasks.append(
asyncio.create_task(
model.container.generate(
- data.prompt, request.state.id, **task_gen_params
+ data.prompt, request.state.id, **task_gen_params.model_dump()
)
)
)
Expand Down

0 comments on commit cef9a03

Please sign in to comment.