Fix non-MM multiturn: Use legacy formatting (#1247)
Jack-Khuu authored Oct 1, 2024
1 parent 3c0f180 commit edaa15c
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions torchchat/usages/openai_api.py
@@ -21,6 +21,7 @@

 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import Generator, GeneratorArgs
+from torchchat.model import FlamingoModel

 from torchchat.utils.build_utils import device_sync

@@ -363,9 +364,24 @@ def chunked_completion(self, completion_request: CompletionRequest):

         device_sync(device=self.builder_args.device)

-        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-            completion_request
-        )
+        # If the underlying model is Llama 3.2 11B, use unified processing
+        if isinstance(self.model, FlamingoModel):
+            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+                completion_request
+            )
+        else:
+            # Otherwise, use the legacy formatting logic
+            tokens = self.chat_formatter.encode_dialog_prompt(
+                dialog=[
+                    {"role": message["role"], "content": message["content"]}
+                    for message in completion_request.messages
+                ]
+            )
+            print("tokens:", self.tokenizer.decode(tokens), flush=True)
+            encoded = torch.tensor(
+                tokens, dtype=torch.int, device=self.builder_args.device
+            )
+            batch = None

         idx = 0
         start_pos = 0
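In effect, the change dispatches on the model type: the multimodal Llama 3.2 11B model (FlamingoModel) keeps the unified _gen_model_inputs_from_openai_completion_request path, while every other model falls back to the legacy chat-format encoding, which is what restores non-multimodal multiturn behavior. The following is a minimal, self-contained sketch of that dispatch pattern, not the real torchchat code: StubChatFormatter, TextOnlyModel, CompletionRequest, and build_inputs below are hypothetical stand-ins for torchchat's chat_formatter, model classes, request object, and the logic inside chunked_completion.

# Hypothetical sketch of the dispatch introduced by this commit; the stub
# classes below stand in for torchchat's FlamingoModel, chat formatter, and
# CompletionRequest, not the real implementations.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

class FlamingoModel:          # stand-in for torchchat.model.FlamingoModel
    pass

class TextOnlyModel:          # stand-in for any non-multimodal model
    pass

@dataclass
class CompletionRequest:      # stand-in for the OpenAI-style request object
    messages: List[Dict[str, str]]

class StubChatFormatter:      # stand-in for the legacy chat formatter
    def encode_dialog_prompt(self, dialog: List[Dict[str, str]]) -> List[int]:
        # A real formatter emits role/header special tokens; here we just
        # produce one placeholder id per word so the sketch stays runnable.
        text = " ".join(f"{m['role']}: {m['content']}" for m in dialog)
        return list(range(len(text.split())))

def build_inputs(model: Any, req: CompletionRequest,
                 formatter: StubChatFormatter) -> Tuple[List[int], Optional[Any]]:
    # Mirror the commit's branch: unified processing only for FlamingoModel.
    if isinstance(model, FlamingoModel):
        raise NotImplementedError("unified multimodal input builder goes here")
    # Legacy path: flatten the multiturn dialog, tokenize it, leave batch unset.
    tokens = formatter.encode_dialog_prompt(
        dialog=[{"role": m["role"], "content": m["content"]} for m in req.messages]
    )
    return tokens, None

req = CompletionRequest(messages=[
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
    {"role": "user", "content": "Continue the conversation"},
])
encoded, batch = build_inputs(TextOnlyModel(), req, StubChatFormatter())
print(len(encoded), batch)    # legacy path: plain token ids, batch is None

The key design point the commit relies on is that only the Flamingo (multimodal) path needs the unified batch structure; text-only models can be served from a flat token tensor, so batch is simply None on that branch.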
