Skip to content

Commit

Permalink
Lint and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack-Khuu committed Oct 1, 2024
1 parent da475eb commit 29f5204
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
53 changes: 39 additions & 14 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,8 +732,27 @@ def _callback(self, x, *, buffer, done_generating):
print("".join(buffer), end="", flush=True)
buffer.clear()
# print(, end='', flush=True)

def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:

def _gen_model_input(
self,
prompt: Union[str | List[Any]],
image_prompts: Optional[List[str | Image.Image]] = None,
max_new_tokens: Optional[int] = None,
) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
"""
Convert prompt and image prompts into consumable model input args.
When prompt is a list, the anticipated format is OpenAI API Inspired:
[ ..., {"role": message["role"], "content": message["content"]}, ...]
Args:
prompt (Union[str, List[Any]]): Prompt or list of dialog.
image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
Returns:
Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
"""

# Not Llama 3.2 11B
if self.model.config.model_type != ModelType.Flamingo:
Expand All @@ -753,14 +772,20 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
return encoded, None

# Llama 3.2 11B
assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
assert (
image_prompts is None or len(image_prompts) == 1
), "At most one image is supported at the moment"
if image_prompts and isinstance(image_prompts[0], str):
images = [Image.open(image_prompts[0])]
else:
images = image_prompts

assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
assert isinstance(prompt, str), "(Currently) prompt must be a str for Flamingo models"
assert (
max_new_tokens is not None
), "max_new_tokens must be specified for Flamingo models"
assert isinstance(
prompt, str
), "(Currently) prompt must be a str for Flamingo models"

is_multimodal = images is not None
content = [{"type": "text", "content": prompt}]
Expand Down Expand Up @@ -791,21 +816,21 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
encoded = batch.pop("tokens").to(device).view(-1)
seq_len = encoded.size(0)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
self.dtype
)
else:
encoded = torch.tensor(
data["tokens"], device=device
).view(-1)
encoded = torch.tensor(data["tokens"], device=device).view(-1)
seq_len = encoded.size(0)
batch = {}

total_response_length = seq_len + max_new_tokens
batch["causal_mask"] = torch.tril(
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
)
)
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
)
)

logging.debug(encoded)
return encoded, batch
Expand Down
4 changes: 3 additions & 1 deletion torchchat/usages/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def _gen_model_inputs_from_openai_completion_request(
{"role": message["role"], "content": message["content"]}
for message in completion_request.messages
]
return self._gen_model_input(prompt=prompt, max_new_tokens=completion_request.max_tokens)
return self._gen_model_input(
prompt=prompt, max_new_tokens=completion_request.max_tokens
)

# Llama 3.2 11B
prompt = None
Expand Down

0 comments on commit 29f5204

Please sign in to comment.