
Commit b151de4

ixlmar and Funatiq authored
[TRTLLM-8377][test] unit tests for TorchSampler batched sampling (#9012)
Signed-off-by: ixlmar <[email protected]>
Co-authored-by: Robin Kobus <[email protected]>
1 parent b894dc2 commit b151de4

File tree

3 files changed: +1375 −18 lines


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 4 additions & 2 deletions
@@ -673,15 +673,17 @@ def get_generator(self, device: torch.device) -> torch.Generator:
         assert self._generator.device == device
         return self._generator
 
-    def get_spec_tree_manager(self, resource_manager: ResourceManager) -> Optional[SpecTreeManager]:
+    def get_spec_tree_manager(
+        self, resource_manager: Optional[ResourceManager]
+    ) -> Optional[SpecTreeManager]:
         if resource_manager is None:
             return None
         spec_resource_manager = resource_manager.get_resource_manager(
             ResourceManagerType.SPEC_RESOURCE_MANAGER
         )
         if spec_resource_manager is None or not hasattr(spec_resource_manager, "spec_tree_manager"):
             return None
-        return spec_resource_manager.spec_tree_manager
+        return spec_resource_manager.spec_tree_manager  # type: ignore
 
     @staticmethod
     def _meet_max_token_stop_criteria(request: LlmRequest, max_seq_len: int):
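
For context, here is a self-contained sketch of the None-tolerant accessor pattern the widened signature enables; the stand-in classes below are hypothetical simplifications, not the real tensorrt_llm types:

from typing import Optional

# Hypothetical stand-ins, just to make the sketch runnable on its own.
class SpecTreeManager:
    pass

class SpecResourceManager:
    spec_tree_manager = SpecTreeManager()

class ResourceManager:
    def get_resource_manager(self, kind: str) -> Optional[SpecResourceManager]:
        return SpecResourceManager()

def get_spec_tree_manager(
    resource_manager: Optional[ResourceManager],
) -> Optional[SpecTreeManager]:
    # Mirrors the diff: a None resource manager short-circuits to None
    # instead of violating the annotated parameter type.
    if resource_manager is None:
        return None
    spec_resource_manager = resource_manager.get_resource_manager("SPEC_RESOURCE_MANAGER")
    if spec_resource_manager is None or not hasattr(spec_resource_manager, "spec_tree_manager"):
        return None
    return spec_resource_manager.spec_tree_manager

print(get_spec_tree_manager(None))               # None
print(get_spec_tree_manager(ResourceManager()))  # SpecTreeManager instance

Accepting Optional[ResourceManager] matches the body, which already handled None; the signature change makes that contract visible to callers such as unit tests that construct a sampler without a resource manager.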

tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py

Lines changed: 10 additions & 13 deletions
@@ -91,17 +91,14 @@ def _make_tensor(data: list, dtype: torch.dtype, device: torch.device) -> torch.
     def _prepare_logits_with_temperature(
         logits: torch.Tensor,
         group_logit_indices: Optional[torch.Tensor],
-        temperature: Optional[torch.Tensor],
+        temperature: torch.Tensor,
     ) -> torch.Tensor:
-        if temperature is not None:
-            temperature = temperature.unsqueeze(-1)
-            if group_logit_indices is not None:
-                logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
-                logits /= temperature
-            else:
-                logits = logits / temperature  # not inplace
-        elif group_logit_indices is not None:
-            logits = logits[group_logit_indices]
+        temperature = temperature.unsqueeze(-1)
+        if group_logit_indices is not None:
+            logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
+            logits /= temperature
+        else:
+            logits = logits / temperature  # not inplace
         return logits
 
     @staticmethod
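
A minimal torch-only sketch of the semantics this hunk preserves, assuming hypothetical example shapes:

import torch

# `temperature` is now required, so the old `if temperature is not None`
# branch is gone. index_select returns a copy, which makes the in-place
# division safe; without indices, the out-of-place division allocates a
# fresh tensor. Either way the caller's logits are never mutated.
logits = torch.randn(4, 8)                   # [num_tokens, vocab_size]
temperature = torch.tensor([1.0, 0.5, 2.0])  # one value per selected row
group_logit_indices = torch.tensor([0, 2, 3])

selected = torch.index_select(logits, 0, group_logit_indices)  # copy
selected /= temperature.unsqueeze(-1)        # safe in-place on the copy

expected = logits[group_logit_indices] / temperature.unsqueeze(-1)
assert torch.allclose(selected, expected)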
@@ -112,12 +109,12 @@ def _prepare_probs_with_temperature(
     ) -> torch.Tensor:
         if group_logit_indices is not None:
             logits = logits[group_logit_indices]
-        logits = flashinfer.sampling.softmax(
+        probs = flashinfer.sampling.softmax(
             logits,
             temperature,
             enable_pdl=ENABLE_PDL,
         )
-        return logits
+        return probs
 
     @classmethod
     def _sample_from_probs(
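
The rename clarifies that the return value is a probability distribution, not logits. Expressed with plain torch, the renamed `probs` holds a softmax over temperature-scaled logits; that flashinfer.sampling.softmax with a per-row temperature matches this reference is an assumption read off the call site above:

import torch

logits = torch.randn(2, 16)
temperature = torch.tensor([0.7, 1.3])

# Reference for what `probs` should contain after the fused kernel runs.
probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))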
@@ -151,7 +148,7 @@ def _sample_with_probs(
         group_logit_indices: Optional[torch.Tensor],
         top_k: Optional[torch.Tensor],
         top_p: Optional[torch.Tensor],
-        temperature: Optional[torch.Tensor],
+        temperature: torch.Tensor,
         generator: Optional[torch.Generator],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         if top_k is not None:
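
The third changed file, the new unit tests, is not shown in this diff view. As a flavor of what a batched-sampling unit test can assert, here is a hedged torch-only sketch checking that sampling a batch from identical probability rows is reproducible under a fixed generator seed; torch.multinomial stands in for the sampler's actual kernel and the real tests exercise TorchSampler itself:

import torch

def test_batched_sampling_reproducible() -> None:
    probs = torch.full((4, 8), 1.0 / 8)  # four uniform rows over 8 tokens
    gen_a = torch.Generator().manual_seed(1234)
    gen_b = torch.Generator().manual_seed(1234)
    tokens_a = torch.multinomial(probs, num_samples=1, generator=gen_a)
    tokens_b = torch.multinomial(probs, num_samples=1, generator=gen_b)
    assert torch.equal(tokens_a, tokens_b)

test_batched_sampling_reproducible()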
