-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Misc] Consolidate and optimize logic for building padded tensors #6541
[Misc] Consolidate and optimize logic for building padded tensors #6541
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some explanations
_STR_DTYPE_TO_TORCH_DTYPE = { | ||
"half": torch.half, | ||
"bfloat16": torch.bfloat16, | ||
"float": torch.float, | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found that this mapping is a duplicate of the one in vllm.utils
, so I've removed it.
@@ -466,22 +465,30 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], | |||
do_penalties = prompt_tokens or output_tokens | |||
|
|||
if do_penalties: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have merged the if-else blocks based on do_penalties
together. Not sure why they were separated in the first place.
|
||
The padding is applied to the end of each inner list until it reaches | ||
`max_len`. | ||
""" | ||
padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad | ||
padded_x = np.full((len(x), max_len), pad, dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.full(..., pad)
is more efficient than np.zeros(...) + pad
. Try it yourself:
python -m timeit "import numpy as np; np.zeros(100000) + 2"
python -m timeit "import numpy as np; np.full(100000, 2)"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've also fixed the dtype to be consistent with the pytorch one.
64382b3
to
38c5ab8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just nits
…lm-project#6541) Signed-off-by: Alvant <[email protected]>
Following #6442 , this PR introduces a small refactor to clean up the code.
cc @peng1999