diff --git a/torchchat/generate.py b/torchchat/generate.py
index c38fcaff5..349ee45c5 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -732,67 +732,106 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+
+    def _gen_model_input(
+        self,
+        prompt: Union[str | List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
+            [ ..., {"role": message["role"], "content": message["content"]}, ...]
+
+        Args:
+            prompt (Union[str, List[Any]]): Prompt or list of dialog.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single String prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts
 
-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"
 
-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]
 
-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content
 
-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]
 
-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)
 
-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)
 
-                if is_multimodal:
-                    batch = padded_collate_tiled_images_and_mask(
-                        [data], pad_direction="left", pad_max_images=1
-                    )
-                    encoded = batch.pop("tokens").to(device).view(-1)
-                    seq_len = encoded.size(0)
-                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-                else:
-                    encoded = torch.tensor(
-                        data["tokens"], device=device
-                    ).view(-1)
-                    seq_len = encoded.size(0)
-                    batch = {}
-
-                total_response_length = seq_len + max_new_tokens
-                batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
-                )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
+            if is_multimodal:
+                batch = padded_collate_tiled_images_and_mask(
+                    [data], pad_direction="left", pad_max_images=1
+                )
+                encoded = batch.pop("tokens").to(device).view(-1)
+                seq_len = encoded.size(0)
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                    self.dtype
+                )
+            else:
+                encoded = torch.tensor(data["tokens"], device=device).view(-1)
+                seq_len = encoded.size(0)
+                batch = {}
+
+            total_response_length = seq_len + max_new_tokens
+            batch["causal_mask"] = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
             )
-            batch = None
-
+
         logging.debug(encoded)
         return encoded, batch
 
diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index 9e3661fa5..93de6e0ec 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -310,6 +310,17 @@ def _gen_model_inputs_from_openai_completion_request(
         """
         messages = completion_request.messages
 
+        # Not Llama 3.2 11B
+        if not isinstance(self.model, FlamingoModel):
+            prompt = [
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
+            return self._gen_model_input(
+                prompt=prompt, max_new_tokens=completion_request.max_tokens
+            )
+
+        # Llama 3.2 11B
         prompt = None
         images = None
 
@@ -361,27 +372,10 @@ def chunked_completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
 
-        device_sync(device=self.builder_args.device)
-
-        # If the underlying model is LLama3.2 11B, used unified processing
-        if isinstance(self.model, FlamingoModel):
-            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-                completion_request
-            )
-        else:
-            # Else use the legacy formatting logic
-            tokens = self.chat_formatter.encode_dialog_prompt(
-                dialog=[
-                    {"role": message["role"], "content": message["content"]}
-                    for message in completion_request.messages
-                ]
-            )
-            print("tokens:", self.tokenizer.decode(tokens), flush=True)
-            encoded = torch.tensor(
-                tokens, dtype=torch.int, device=self.builder_args.device
-            )
-            batch = None
+        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+            completion_request
+        )
 
         idx = 0
         start_pos = 0
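For reviewers, the core behavioral change is the new dispatch at the top of `_gen_model_input`: a plain string prompt is tokenized directly, while an OpenAI-style list of `{"role": ..., "content": ...}` messages is routed through the chat formatter, so the OpenAI API handler no longer needs its own legacy branch. The sketch below isolates that dispatch outside of torchchat; `StubTokenizer` and `StubChatFormatter` are hypothetical stand-ins (not torchchat classes), and the encoding they perform is illustrative only. Only the string-vs-dialog routing mirrors the diff.

```python
# Minimal sketch of the non-Flamingo dispatch in _gen_model_input.
# StubTokenizer / StubChatFormatter are hypothetical stand-ins, not torchchat code.
from typing import Any, List, Union

import torch


class StubTokenizer:
    def encode(self, text: str, bos: bool = True) -> List[int]:
        # Illustrative byte-level "tokenization"; real tokenizers differ.
        ids = list(text.encode("utf-8"))
        return ([1] + ids) if bos else ids


class StubChatFormatter:
    def __init__(self, tokenizer: StubTokenizer):
        self.tokenizer = tokenizer

    def encode_dialog_prompt(self, dialog: List[Any]) -> List[int]:
        # Flatten OpenAI-style messages into one token stream (illustrative).
        tokens: List[int] = []
        for message in dialog:
            line = f'{message["role"]}: {message["content"]}'
            tokens += self.tokenizer.encode(line, bos=False)
        return [1] + tokens  # prepend a BOS-like token


def gen_model_input_sketch(
    prompt: Union[str, List[Any]],
    tokenizer: StubTokenizer,
    chat_formatter: StubChatFormatter,
    device: str = "cpu",
) -> torch.Tensor:
    if isinstance(prompt, str):
        # Single string prompt: tokenize directly, with BOS.
        tokens = tokenizer.encode(prompt, bos=True)
    else:
        # OpenAI-style dialog: [{"role": ..., "content": ...}, ...]
        tokens = chat_formatter.encode_dialog_prompt(prompt)
    return torch.tensor(tokens, dtype=torch.int, device=device)


if __name__ == "__main__":
    tok = StubTokenizer()
    fmt = StubChatFormatter(tok)
    print(gen_model_input_sketch("Hello", tok, fmt).shape)
    print(gen_model_input_sketch([{"role": "user", "content": "Hello"}], tok, fmt).shape)
```

In the actual patch the list path returns `(encoded, None)` early, which is why `chunked_completion` can call `_gen_model_inputs_from_openai_completion_request` unconditionally and still receive `batch = None` for text-only models.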