Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
technillogue committed May 17, 2024
1 parent 72f560d commit 2f893d2
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from sse import receive_sse
from triton_config_generator import generate_configs, load_yaml_config
from utils import StreamingTokenStopSequenceHandler, maybe_download_tarball_with_pget
from utils import (
StreamingTokenStopSequenceHandler,
maybe_download_tarball_with_pget,
)

TRITONSERVER_DIST_DIR = (
pytriton.utils.distribution.get_root_module_path() / "tritonserver"
Expand Down Expand Up @@ -65,8 +68,8 @@ async def setup(self, weights: str = "") -> None:
print(f"tensorrt_llm config: {config}")

max_seqlen_env = os.getenv("MAX_SEQUENCE_LENGTH", None)
if max_seqlen_env :
self.max_sequence_length = int(max_seqlen_env )
if max_seqlen_env:
self.max_sequence_length = int(max_seqlen_env)
else:
try:
self.max_sequence_length = self.trt_llm_config["pretrained_config"][
Expand Down Expand Up @@ -215,12 +218,16 @@ async def predict(
if max_tokens == 512 or max_tokens is None:
max_tokens = max_new_tokens
else:
raise Exception(f"Can't set both max_tokens ({max_tokens}) and max_new_tokens ({max_new_tokens})")
raise Exception(
f"Can't set both max_tokens ({max_tokens}) and max_new_tokens ({max_new_tokens})"
)
if min_new_tokens:
if min_tokens is None:
min_tokens = min_new_tokens
else:
raise Exception(f"Can't set both min_tokens ({min_tokens}) and min_new_tokens ({min_new_tokens})")
raise Exception(
f"Can't set both min_tokens ({min_tokens}) and min_new_tokens ({min_new_tokens})"
)

args = self._process_args(
prompt=formatted_prompt,
Expand Down Expand Up @@ -268,7 +275,7 @@ async def predict(
tokens = np.append(tokens, token)
output = self.tokenizer.decode(tokens, skip_special_tokens=True)
# Catches partial emojis, waits for them to finish
output = output.replace("\N{Replacement Character}", "")
output = output.replace("\N{REPLACEMENT CHARACTER}", "")
# Remove the tokens that were already yielded
current_output = output[generation_length:]

Expand Down

0 comments on commit 2f893d2

Please sign in to comment.