From 2f893d25189a2b38b8762b7af6ad42bab1b08190 Mon Sep 17 00:00:00 2001
From: technillogue
Date: Fri, 17 May 2024 18:46:13 -0400
Subject: [PATCH] format

---
 predict.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/predict.py b/predict.py
index 644cbca..56e0572 100644
--- a/predict.py
+++ b/predict.py
@@ -16,7 +16,10 @@
 from sse import receive_sse
 from triton_config_generator import generate_configs, load_yaml_config
 
-from utils import StreamingTokenStopSequenceHandler, maybe_download_tarball_with_pget
+from utils import (
+    StreamingTokenStopSequenceHandler,
+    maybe_download_tarball_with_pget,
+)
 
 TRITONSERVER_DIST_DIR = (
     pytriton.utils.distribution.get_root_module_path() / "tritonserver"
@@ -65,8 +68,8 @@ async def setup(self, weights: str = "") -> None:
         print(f"tensorrt_llm config: {config}")
 
         max_seqlen_env = os.getenv("MAX_SEQUENCE_LENGTH", None)
-        if max_seqlen_env :
-            self.max_sequence_length = int(max_seqlen_env )
+        if max_seqlen_env:
+            self.max_sequence_length = int(max_seqlen_env)
         else:
             try:
                 self.max_sequence_length = self.trt_llm_config["pretrained_config"][
@@ -215,12 +218,16 @@ async def predict(
             if max_tokens == 512 or max_tokens is None:
                 max_tokens = max_new_tokens
             else:
-                raise Exception(f"Can't set both max_tokens ({max_tokens}) and max_new_tokens ({max_new_tokens})")
+                raise Exception(
+                    f"Can't set both max_tokens ({max_tokens}) and max_new_tokens ({max_new_tokens})"
+                )
         if min_new_tokens:
             if min_tokens is None:
                 min_tokens = min_new_tokens
             else:
-                raise Exception(f"Can't set both min_tokens ({min_tokens}) and min_new_tokens ({min_new_tokens})")
+                raise Exception(
+                    f"Can't set both min_tokens ({min_tokens}) and min_new_tokens ({min_new_tokens})"
+                )
 
         args = self._process_args(
             prompt=formatted_prompt,
@@ -268,7 +275,7 @@ async def predict(
                 tokens = np.append(tokens, token)
                 output = self.tokenizer.decode(tokens, skip_special_tokens=True)
                 # Catches partial emojis, waits for them to finish
-                output = output.replace("\N{Replacement Character}", "")
+                output = output.replace("\N{REPLACEMENT CHARACTER}", "")
                 # Remove the tokens that were already yielded
                 current_output = output[generation_length:]
 
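
Context for the last hunk: Python resolves \N{...} escape names case-insensitively, so upper-casing to the canonical Unicode name is purely cosmetic, consistent with the "format" subject. The replace() call it touches exists because incrementally decoding a token stream can surface U+FFFD while a multi-byte code point (such as an emoji) has only partially arrived. A minimal standalone sketch of that effect, using only the standard library; the example bytes below are illustrative and not taken from predict.py:

# Sketch: why a streaming decode can emit U+FFFD mid-emoji.
smiley = "\N{GRINNING FACE}".encode("utf-8")             # a 4-byte UTF-8 sequence
partial = smiley[:2].decode("utf-8", errors="replace")   # truncated mid-sequence
full = smiley.decode("utf-8", errors="replace")          # complete sequence
assert "\N{REPLACEMENT CHARACTER}" in partial      # placeholder while bytes are missing
assert "\N{REPLACEMENT CHARACTER}" not in full     # gone once the code point completes

Stripping the replacement character from the decoded text, as predict.py does, means a partially received emoji yields nothing on this pass and the full character is emitted on a later iteration once all of its tokens have arrived.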