Absorb non-MM OpenAI dialog parsing into generic input parsing #1248

Merged (5 commits) on Oct 2, 2024
135 changes: 87 additions & 48 deletions torchchat/generate.py
@@ -732,67 +732,106 @@ def _callback(self, x, *, buffer, done_generating):
            print("".join(buffer), end="", flush=True)
            buffer.clear()
        # print(, end='', flush=True)

-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+    def _gen_model_input(
+        self,
+        prompt: Union[str | List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
+            [ ..., {"role": message["role"], "content": message["content"]}, ...]
+
+        Args:
+            prompt (Union[str, List[Any]]): Prompt or list of dialog.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single String prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
        if image_prompts and isinstance(image_prompts[0], str):
            images = [Image.open(image_prompts[0])]
        else:
            images = image_prompts

-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"

-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]

-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content

-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]

-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)

-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)

-                if is_multimodal:
-                    batch = padded_collate_tiled_images_and_mask(
-                        [data], pad_direction="left", pad_max_images=1
-                    )
-                    encoded = batch.pop("tokens").to(device).view(-1)
-                    seq_len = encoded.size(0)
-                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-                else:
-                    encoded = torch.tensor(
-                        data["tokens"], device=device
-                    ).view(-1)
-                    seq_len = encoded.size(0)
-                    batch = {}
-
-                total_response_length = seq_len + max_new_tokens
-                batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
-                )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
-            )
-            batch = None
+            if is_multimodal:
+                batch = padded_collate_tiled_images_and_mask(
+                    [data], pad_direction="left", pad_max_images=1
+                )
+                encoded = batch.pop("tokens").to(device).view(-1)
+                seq_len = encoded.size(0)
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                    self.dtype
+                )
+            else:
+                encoded = torch.tensor(data["tokens"], device=device).view(-1)
+                seq_len = encoded.size(0)
+                batch = {}
+
+            total_response_length = seq_len + max_new_tokens
+            batch["causal_mask"] = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )

        logging.debug(encoded)
        return encoded, batch

Review comment on lines +783 to +832 (Contributor, Author): Lint and white space
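To illustrate the unified path this hunk introduces, here is a minimal standalone sketch of the string-vs-dialog dispatch in the non-multimodal branch of _gen_model_input. The toy encoders below are stand-ins for torchchat's tokenizer and chat formatter, not real torchchat APIs; only the dispatch shape mirrors the diff.

# Standalone sketch (assumption: toy encoders in place of torchchat's tokenizer
# and chat formatter). Both a plain string and an OpenAI-API-inspired dialog
# list flow through a single entry point, as in the new _gen_model_input.
from typing import Any, Dict, List, Union

def encode_plain(prompt: str) -> List[int]:
    # Stand-in for Generator.encode_tokens(prompt, bos=True, ...)
    return [0] + [ord(c) for c in prompt]

def encode_dialog(dialog: List[Dict[str, Any]]) -> List[int]:
    # Stand-in for chat_formatter.encode_dialog_prompt(dialog)
    flat = "".join(f"<{m['role']}>{m['content']}" for m in dialog)
    return [ord(c) for c in flat]

def gen_model_input(prompt: Union[str, List[Dict[str, Any]]]) -> List[int]:
    # Mirrors the new non-Flamingo branch: both input shapes share one path.
    if isinstance(prompt, str):
        return encode_plain(prompt)
    return encode_dialog(prompt)

dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(gen_model_input("Hello!")[:5])
print(gen_model_input(dialog)[:5])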

34 changes: 14 additions & 20 deletions torchchat/usages/openai_api.py
@@ -310,6 +310,17 @@ def _gen_model_inputs_from_openai_completion_request(
        """
        messages = completion_request.messages

+        # Not Llama 3.2 11B
+        if not isinstance(self.model, FlamingoModel):
+            prompt = [
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
+            return self._gen_model_input(
+                prompt=prompt, max_new_tokens=completion_request.max_tokens
+            )
+
+        # Llama 3.2 11B
        prompt = None
        images = None

@@ -361,27 +372,10 @@ def chunked_completion(self, completion_request: CompletionRequest):

        # Initialize counters for chunk responses and encode the prompt.
        id = str(uuid.uuid4())

        device_sync(device=self.builder_args.device)

-        # If the underlying model is LLama3.2 11B, used unified processing
-        if isinstance(self.model, FlamingoModel):
-            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-                completion_request
-            )
-        else:
-            # Else use the legacy formatting logic
-            tokens = self.chat_formatter.encode_dialog_prompt(
-                dialog=[
-                    {"role": message["role"], "content": message["content"]}
-                    for message in completion_request.messages
-                ]
-            )
-            print("tokens:", self.tokenizer.decode(tokens), flush=True)
-            encoded = torch.tensor(
-                tokens, dtype=torch.int, device=self.builder_args.device
-            )
-            batch = None
+        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+            completion_request
+        )

        idx = 0
        start_pos = 0
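For reference, a hedged client-side sketch of the request shape that reaches _gen_model_inputs_from_openai_completion_request after this change: the "messages" list maps directly to completion_request.messages, and "max_tokens" to max_new_tokens. The URL, port, and model alias are assumptions about a locally running torchchat OpenAI-compatible server, not values taken from this diff.

# Hypothetical request against an assumed local torchchat server endpoint;
# the "messages" payload is the dialog the handler now forwards to
# _gen_model_input for non-multimodal models.
import json
import urllib.request

payload = {
    "model": "llama3.1",  # assumed model alias
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "max_tokens": 128,  # becomes completion_request.max_tokens -> max_new_tokens
    "stream": False,
}

req = urllib.request.Request(
    "http://127.0.0.1:5000/v1/chat/completions",  # assumed local endpoint
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8")))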