From e840b4eef470aec63740e93245402c706b5dc43c Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 1 Oct 2024 11:30:23 -0700
Subject: [PATCH] Fix non-MM multiturn: Use legacy formatting

---
 torchchat/usages/openai_api.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index 8cdd8849d..9e3661fa5 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -21,6 +21,7 @@
 
 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import Generator, GeneratorArgs
+from torchchat.model import FlamingoModel
 
 from torchchat.utils.build_utils import device_sync
 
@@ -363,9 +364,24 @@ def chunked_completion(self, completion_request: CompletionRequest):
 
         device_sync(device=self.builder_args.device)
 
-        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-            completion_request
-        )
+        # If the underlying model is Llama 3.2 11B, use unified processing
+        if isinstance(self.model, FlamingoModel):
+            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+                completion_request
+            )
+        else:
+            # Otherwise, use the legacy text-only formatting logic
+            tokens = self.chat_formatter.encode_dialog_prompt(
+                dialog=[
+                    {"role": message["role"], "content": message["content"]}
+                    for message in completion_request.messages
+                ]
+            )
+            print("tokens:", self.tokenizer.decode(tokens), flush=True)
+            encoded = torch.tensor(
+                tokens, dtype=torch.int, device=self.builder_args.device
+            )
+            batch = None
 
         idx = 0
         start_pos = 0
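
Reviewer note, not part of the patch: below is a minimal sketch of how the
non-multimodal multiturn path fixed here can be exercised against a locally
running torchchat OpenAI-compatible server. The URL, port, and model alias
are assumptions for illustration, not values defined by this patch.

    # Sketch: a multi-turn, text-only chat completion request. With this patch,
    # a non-Flamingo model routes this dialog through
    # chat_formatter.encode_dialog_prompt instead of the multimodal input path.
    import requests

    payload = {
        "model": "llama3.1",  # hypothetical alias; use whatever the server was started with
        "stream": True,
        "messages": [
            {"role": "user", "content": "Name a prime number."},
            {"role": "assistant", "content": "7 is a prime number."},
            {"role": "user", "content": "Name a larger one."},
        ],
    }

    # Assumes a server is already running locally (e.g. started with
    # `python3 torchchat.py server <model>`); the port here is an assumption.
    with requests.post(
        "http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True
    ) as resp:
        for line in resp.iter_lines():
            if line:
                print(line.decode("utf-8"))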