Absorb non-MM OpenAI dialog parsing into generic input parsing (#1248)
* Fix non-MM multiturn: Use legacy formatting

* Absorb non-MM OpenAI dialog parsing into generic input parsing

* Lint and docstrings
Jack-Khuu authored Oct 2, 2024
1 parent edaa15c commit 58185b6
Showing 2 changed files with 101 additions and 68 deletions.
135 changes: 87 additions & 48 deletions torchchat/generate.py
@@ -732,67 +732,106 @@ def _callback(self, x, *, buffer, done_generating):
            print("".join(buffer), end="", flush=True)
            buffer.clear()
        # print(, end='', flush=True)

-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+    def _gen_model_input(
+        self,
+        prompt: Union[str | List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
+            [ ..., {"role": message["role"], "content": message["content"]}, ...]
+        Args:
+            prompt (Union[str, List[Any]]): Prompt or list of dialog.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single String prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"

        if image_prompts and isinstance(image_prompts[0], str):
            images = [Image.open(image_prompts[0])]
        else:
            images = image_prompts

-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"

-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]

-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content

-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]

-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)

-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)

-                if is_multimodal:
-                    batch = padded_collate_tiled_images_and_mask(
-                        [data], pad_direction="left", pad_max_images=1
-                    )
-                    encoded = batch.pop("tokens").to(device).view(-1)
-                    seq_len = encoded.size(0)
-                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-                else:
-                    encoded = torch.tensor(
-                        data["tokens"], device=device
-                    ).view(-1)
-                    seq_len = encoded.size(0)
-                    batch = {}
-
-                total_response_length = seq_len + max_new_tokens
-                batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
-                )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
-            )
-            batch = None
+            if is_multimodal:
+                batch = padded_collate_tiled_images_and_mask(
+                    [data], pad_direction="left", pad_max_images=1
+                )
+                encoded = batch.pop("tokens").to(device).view(-1)
+                seq_len = encoded.size(0)
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                    self.dtype
+                )
+            else:
+                encoded = torch.tensor(data["tokens"], device=device).view(-1)
+                seq_len = encoded.size(0)
+                batch = {}
+
+            total_response_length = seq_len + max_new_tokens
+            batch["causal_mask"] = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )

        logging.debug(encoded)
        return encoded, batch

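To make the new dispatch concrete, here is a minimal standalone sketch (not torchchat code) of the control flow the refactored _gen_model_input applies to non-Flamingo models: a plain string is tokenized directly, while an OpenAI-style list of {"role", "content"} messages goes through the chat formatter. StubTokenizer and StubChatFormatter are hypothetical stand-ins so the example runs on its own.

from typing import Any, Dict, List, Union

import torch


class StubTokenizer:
    """Hypothetical stand-in for torchchat's real tokenizer."""

    def encode(self, text: str) -> List[int]:
        # Toy byte-level "tokenization"; real models use BPE/SentencePiece.
        return [ord(c) % 256 for c in text]


class StubChatFormatter:
    """Hypothetical stand-in for the formatter behind encode_dialog_prompt."""

    def __init__(self, tokenizer: StubTokenizer) -> None:
        self.tokenizer = tokenizer

    def encode_dialog_prompt(self, dialog: List[Dict[str, str]]) -> List[int]:
        # Real formatters insert role headers and special tokens; this just
        # concatenates "role: content" per message.
        tokens: List[int] = []
        for message in dialog:
            tokens.extend(
                self.tokenizer.encode(f"{message['role']}: {message['content']}\n")
            )
        return tokens


def gen_model_input(prompt: Union[str, List[Any]], formatter: StubChatFormatter) -> torch.Tensor:
    # Mirrors the non-Flamingo branch: a plain string is encoded directly,
    # an OpenAI-style dialog list is routed through the chat formatter.
    if isinstance(prompt, str):
        tokens = formatter.tokenizer.encode(prompt)
    else:
        tokens = formatter.encode_dialog_prompt(prompt)
    return torch.tensor(tokens, dtype=torch.int)


if __name__ == "__main__":
    fmt = StubChatFormatter(StubTokenizer())
    print(gen_model_input("Write a haiku about autumn.", fmt).shape)
    dialog = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about autumn."},
    ]
    print(gen_model_input(dialog, fmt).shape)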
34 changes: 14 additions & 20 deletions torchchat/usages/openai_api.py
@@ -310,6 +310,17 @@ def _gen_model_inputs_from_openai_completion_request(
        """
        messages = completion_request.messages

+        # Not Llama 3.2 11B
+        if not isinstance(self.model, FlamingoModel):
+            prompt = [
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
+            return self._gen_model_input(
+                prompt=prompt, max_new_tokens=completion_request.max_tokens
+            )
+
+        # Llama 3.2 11B
        prompt = None
        images = None

@@ -361,27 +372,10 @@ def chunked_completion(self, completion_request: CompletionRequest):

        # Initialize counters for chunk responses and encode the prompt.
        id = str(uuid.uuid4())

        device_sync(device=self.builder_args.device)

-        # If the underlying model is LLama3.2 11B, used unified processing
-        if isinstance(self.model, FlamingoModel):
-            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-                completion_request
-            )
-        else:
-            # Else use the legacy formatting logic
-            tokens = self.chat_formatter.encode_dialog_prompt(
-                dialog=[
-                    {"role": message["role"], "content": message["content"]}
-                    for message in completion_request.messages
-                ]
-            )
-            print("tokens:", self.tokenizer.decode(tokens), flush=True)
-            encoded = torch.tensor(
-                tokens, dtype=torch.int, device=self.builder_args.device
-            )
-            batch = None
+        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+            completion_request
+        )

        idx = 0
        start_pos = 0
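For reference, a minimal standalone sketch (not torchchat code) of the conversion the new non-Flamingo branch of _gen_model_inputs_from_openai_completion_request performs: the OpenAI-style messages are reduced to {"role", "content"} pairs and forwarded, together with max_tokens, to _gen_model_input. CompletionRequestLike is a hypothetical stand-in for the real CompletionRequest, modeling only the fields this diff touches.

from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class CompletionRequestLike:
    """Hypothetical stand-in for torchchat's CompletionRequest."""

    messages: List[Dict[str, str]] = field(default_factory=list)
    max_tokens: Optional[int] = None


def to_dialog(request: CompletionRequestLike) -> List[Dict[str, str]]:
    # The same reduction the new non-Flamingo branch performs before handing
    # the request to _gen_model_input as a list-of-dialog prompt.
    return [
        {"role": m["role"], "content": m["content"]} for m in request.messages
    ]


if __name__ == "__main__":
    req = CompletionRequestLike(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize this commit."},
        ],
        max_tokens=64,
    )
    # In the server this would be forwarded as:
    #   encoded, batch = self._gen_model_input(prompt=dialog, max_new_tokens=req.max_tokens)
    dialog = to_dialog(req)
    print(dialog, req.max_tokens)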
