diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index cb79f89a8ae..ebc851a893c 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -467,6 +467,9 @@ class GenericLlmRequest
         initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
     }

+    GenericLlmRequest(GenericLlmRequest&& request) = default;
+    GenericLlmRequest(GenericLlmRequest const& request) = default;
+
     void setExcludeInputFromOutput(bool exclude)
     {
         mExcludeInputFromOutput = exclude;
@@ -2318,6 +2321,9 @@ class LlmRequest : public GenericLlmRequest
         mKvCacheRetentionConfig = request.getKvCacheRetentionConfig();
     }

+    LlmRequest(LlmRequest&& request) = default;
+    LlmRequest(LlmRequest const& request) = default;
+
     /// @brief Create a Response from the current state of the request
     /// @details Note that there is some dependency on the order of operations in this method. Modify with care!
     /// @return An optional Response
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index 9703631ff58..e4d0237048f 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -187,6 +187,8 @@ void initBindings(nb::module_& m)
         .def_prop_ro("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
         .def_prop_ro("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
         .def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType)
+        .def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
+        .def_prop_ro("is_child", &GenLlmReq::isChild)
         .def_prop_ro("multimodal_hashes",
             [](GenLlmReq& self)
             {
@@ -351,11 +353,13 @@ void initBindings(nb::module_& m)
             nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt,
             nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt,
             nb::arg("context_phase_params") = std::nullopt)
+        .def(nb::init())
         .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"),
             nb::arg("max_draft_len"), nb::arg("vocab_size_padded"), nb::arg("max_endocer_input_len") = std::nullopt,
             nb::arg("enable_kv_cache_reuse") = false)
         .def("create_response", &tb::LlmRequest::createResponse, nb::arg("use_fast_logits") = false,
             nb::arg("mpi_world_rank") = 0)
+        .def("create_child_request", &tb::LlmRequest::createChildRequest, nb::arg("child_id"))
         .def("create_result", &tb::LlmRequest::createResult, nb::arg("use_fast_logits") = false,
             nb::arg("mpi_world_rank") = 0)
         .def("create_serialized_result",
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index a68086f9d68..d92b60cfbac 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -192,6 +192,8 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
         .def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
         .def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
+        .def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
+        .def_property_readonly("is_child", &GenLlmReq::isChild)
         .def_property_readonly("multimodal_hashes",
             [](GenLlmReq& self)
             {
@@ -254,7 +256,7 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);

     py::classh(m, "LlmRequest", pybind11::dynamic_attr())
-        .def(py::init(
+        .def(py::init<>(
             [](tb::LlmRequest::RequestIdType request_id, tb::LlmRequest::SizeType32 max_new_tokens,
                 std::vector input_tokens, runtime::SamplingConfig sampling_config, bool is_streaming,
                 std::optional end_id,
@@ -357,11 +359,13 @@ void initBindings(pybind11::module_& m)
             py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
             py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
             py::arg("context_phase_params") = std::nullopt)
+        .def(py::init())
         .def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
             py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_endocer_input_len") = std::nullopt,
             py::arg("enable_kv_cache_reuse") = false)
         .def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
             py::arg("mpi_world_rank") = 0)
+        .def("create_child_request", &tb::LlmRequest::createChildRequest, py::arg("child_id"))
         .def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
             py::arg("mpi_world_rank") = 0)
         .def("create_serialized_result",
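Illustrative sketch (not part of the patch): the bindings above expose the parent/child machinery to Python as read-only properties and a create_child_request method on the bound LlmRequest. Assuming the internal bindings import path used by llm_request.py further down, the new surface can be spot-checked from Python roughly like this:

    # Sketch only: introspects the newly bound attributes. The import path is
    # taken from the type annotation in tensorrt_llm/_torch/pyexecutor/llm_request.py.
    from tensorrt_llm.bindings.internal.batch_manager import LlmRequest

    for name in ("parent_request_id", "is_child", "create_child_request"):
        # def_property_readonly / .def above should make these visible on the class.
        assert hasattr(LlmRequest, name), f"missing binding: {name}"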
diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py
index 9ee456ce908..9e928781277 100644
--- a/examples/llm-api/quickstart_advanced.py
+++ b/examples/llm-api/quickstart_advanced.py
@@ -107,6 +107,8 @@ def add_llm_args(parser):
     parser.add_argument("--top_k", type=int, default=None)
     parser.add_argument("--top_p", type=float, default=None)
     parser.add_argument('--load_format', type=str, default='auto')
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--best_of', type=int, default=None)
     parser.add_argument('--max_beam_width', type=int, default=1)

     # Speculative decoding
@@ -193,6 +195,7 @@ def setup_llm(args, **kwargs):
         batch_sizes=args.cuda_graph_batch_sizes,
         enable_padding=args.cuda_graph_padding_enabled,
     ) if args.use_cuda_graph else None
+
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
@@ -228,6 +231,15 @@ def setup_llm(args, **kwargs):
         **kwargs,
     )

+    use_beam_search = args.max_beam_width > 1
+    best_of = args.best_of or args.n
+    if use_beam_search:
+        if args.n == 1 and args.best_of is None:
+            args.n = args.max_beam_width
+        assert best_of <= args.max_beam_width, f"beam width: {best_of}, should be less or equal to max_beam_width: {args.max_beam_width}"
+
+    assert best_of >= args.n, f"In sampling mode best_of value: {best_of} should be less or equal to n: {args.n}"
+
     sampling_params = SamplingParams(
         max_tokens=args.max_tokens,
         temperature=args.temperature,
@@ -236,8 +248,9 @@ def setup_llm(args, **kwargs):
         return_context_logits=args.return_context_logits,
         return_generation_logits=args.return_generation_logits,
         logprobs=args.logprobs,
-        n=args.max_beam_width,
-        use_beam_search=args.max_beam_width > 1)
+        n=args.n,
+        best_of=best_of,
+        use_beam_search=use_beam_search)

     return llm, sampling_params

@@ -250,23 +263,23 @@ def main():

     for i, output in enumerate(outputs):
         prompt = output.prompt
-        for beam_idx, beam in enumerate(output.outputs):
-            generated_text = beam.text
+        for sequence_idx, sequence in enumerate(output.outputs):
+            generated_text = sequence.text
             # Skip printing the beam_idx if no beam search was used
-            beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
+            sequence_id_text = f"[{sequence_idx}]" if args.max_beam_width > 1 or args.n > 1 else ""
             print(
-                f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
+                f"[{i}]{sequence_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
             )
             if args.return_context_logits:
                 print(
-                    f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
+                    f"[{i}]{sequence_id_text} Context logits: {output.context_logits}"
                 )
             if args.return_generation_logits:
                 print(
-                    f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
+                    f"[{i}]{sequence_id_text} Generation logits: {sequence.generation_logits}"
                 )
             if args.logprobs:
-                print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
+                print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")


 if __name__ == '__main__':
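Illustrative sketch (not part of the patch): the new --n/--best_of flags above feed straight into SamplingParams. A minimal standalone use of the same parameters through the LLM API, with a placeholder model path, would look roughly like this:

    # Sketch only: "/path/to/model" is a placeholder, not defined by this patch.
    from tensorrt_llm import LLM, SamplingParams

    llm = LLM(model="/path/to/model", backend="pytorch")

    # Return 2 sequences per prompt, sampled from 4 candidates (best_of >= n).
    sampling_params = SamplingParams(max_tokens=32,
                                     temperature=0.8,
                                     top_p=0.95,
                                     n=2,
                                     best_of=4,
                                     use_beam_search=False)

    for output in llm.generate(["The future of AI is"], sampling_params):
        for sequence in output.outputs:
            print(sequence.text)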
diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
index a6f12962314..91e2d091552 100644
--- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
+++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -22,6 +22,7 @@ class RequestQueueItem:
     id: int
     request: Optional[ExecutorRequest] = None
+    child_req_ids: Optional[list] = None
     is_canceled_request: bool = False
     query: Optional[list] = None  # only used in `StarAttention`

@@ -83,6 +84,12 @@ def _get_from_request_queue(
                 pass
         return items

+    @staticmethod
+    def _get_num_child_requests(request: ExecutorRequest) -> int:
+        sampling_config = request.sampling_config
+        return 0 if sampling_config.beam_width > 1 else (
+            sampling_config.num_return_sequences or 1) - 1
+
     def _get_from_waiting_queue(
         self,
         waiting_queue: deque[RequestQueueItem],
@@ -111,6 +118,11 @@ def _get_from_waiting_queue(
         scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
         ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
+            req_item = waiting_queue[0]
+            num_children = len(
+                req_item.child_req_ids) if req_item.child_req_ids else 0
+            if (req_count + 1 + num_children) > max_req_count:
+                break
             req_item = waiting_queue.popleft()
             can_process = self._can_process_attention_dp_request(
                 req_item, scheduling_all_ranks_num_active_requests
@@ -118,7 +130,7 @@ def _get_from_waiting_queue(

             if can_process:
                 items.append(req_item)
-                req_count += 1
+                req_count += 1 + num_children
             else:
                 pending_requests.append(req_item)

@@ -149,17 +161,43 @@ def _can_process_attention_dp_request(

         return False

+    def _get_request_id(self):
+        # (next_request_id + 1) % UINT64_MAX
+        current_id = self.next_request_id
+        self.next_request_id = (self.next_request_id + 1) & ((1 << 64) - 1)
+        return current_id
+
+    def _generate_child_request_ids(
+            self, request: ExecutorRequest) -> List[int] | None:
+        """ Generate child request IDs if needed. """
+        child_req_ids = None
+        num_children = self._get_num_child_requests(request)
+        if num_children > 0:
+            child_req_ids = []
+            for _ in range(num_children):
+                child_req_id = self._get_request_id()
+                if self.enable_iter_perf_stats:
+                    self.start_times[child_req_id] = time.time()
+                child_req_ids.append(child_req_id)
+
+        return child_req_ids
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
         try:
             self.enqueue_lock.acquire()
-            start_time = time.time()
             for request in requests:
-                self.start_times[self.next_request_id] = start_time
+                req_id = self._get_request_id()
+
+                if self.enable_iter_perf_stats:
+                    self.start_times[req_id] = time.time()
+
+                child_req_ids = self._generate_child_request_ids(request)
                 self.request_queue.put(
-                    RequestQueueItem(self.next_request_id, request))
-                req_ids.append(self.next_request_id)
-                self.next_request_id += 1
+                    RequestQueueItem(req_id, request, child_req_ids,
+                                     query=None))
+
+                req_ids.append(req_id)
         finally:
             self.enqueue_lock.release()
         return req_ids
@@ -186,15 +224,18 @@ def enqueue_request(self,
         try:
             self.enqueue_lock.acquire()
             assert self.active, "PyExecutor has already been shutdown."
-            req_id = self.next_request_id
+            req_id = self._get_request_id()
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = time.time()

-            if query is not None:
-                self.request_queue.put(RequestQueueItem(req_id, request, query))
-            else:
-                self.request_queue.put(RequestQueueItem(req_id, request))
-            self.next_request_id += 1
+            child_req_ids = self._generate_child_request_ids(request)
+            self.request_queue.put(
+                RequestQueueItem(
+                    req_id,
+                    request,
+                    child_req_ids=child_req_ids,
+                    query=query,
+                ))
         finally:
             self.enqueue_lock.release()

@@ -530,6 +571,10 @@ def _update_new_active_requests_queue_latency(
             if req_item.id in self.start_times:
                 self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
                     req_item.id)
+            if req_item.child_req_ids:
+                for child_id in req_item.child_req_ids:
+                    self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
+                        child_id)

     @nvtx_range("_merge_requests")
     def _merge_requests(self, new_requests: list[RequestQueueItem]):
@@ -543,12 +588,15 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
             else:
                 raise NotImplementedError(f'unsupport cp type {cp_type}')
         else:
-            return [
-                executor_request_to_llm_request(
-                    req_item.id, req_item.request,
+            req_with_children = []
+            for req_item in new_requests:
+                req = executor_request_to_llm_request(
+                    req_item.id, req_item.request, req_item.child_req_ids,
                     self._should_exclude_last_generation_logits())
-                for req_item in new_requests
-            ]
+                req_with_children.append(req)
+                if req.child_requests:
+                    req_with_children.extend(req.child_requests)
+            return req_with_children

     def _merge_star_attention_requests(self,
                                        new_requests: list[RequestQueueItem]):
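Illustrative sketch (not part of the patch): the scheduling change above counts a parent request together with its children against max_req_count, and only sampling requests (beam_width == 1) fan out into children. A self-contained sketch of that accounting, using simplified stand-in types rather than the real ExecutorRequest/SamplingConfig:

    # Standalone sketch of the child-request accounting used above.
    from dataclasses import dataclass


    @dataclass
    class FakeSamplingConfig:
        beam_width: int = 1
        num_return_sequences: int | None = None


    def num_child_requests(sampling_config: FakeSamplingConfig) -> int:
        # Beam search returns all beams under one request; only sampling
        # mode fans out into one parent plus (num_return_sequences - 1) children.
        if sampling_config.beam_width > 1:
            return 0
        return (sampling_config.num_return_sequences or 1) - 1


    def fits_in_batch(req_count: int, num_children: int, max_req_count: int) -> bool:
        # A request is only scheduled if the parent and all of its children
        # fit into the remaining capacity, as in _get_from_waiting_queue.
        return (req_count + 1 + num_children) <= max_req_count


    assert num_child_requests(FakeSamplingConfig(num_return_sequences=3)) == 2
    assert num_child_requests(FakeSamplingConfig(beam_width=4, num_return_sequences=3)) == 0
    assert fits_in_batch(req_count=6, num_children=2, max_req_count=8) is False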
""" + child_req_ids = None + num_children = self._get_num_child_requests(request) + if num_children > 0: + child_req_ids = [] + for _ in range(num_children): + child_req_id = self._get_request_id() + if self.enable_iter_perf_stats: + self.start_times[child_req_id] = time.time() + child_req_ids.append(child_req_id) + + return child_req_ids + def enqueue_requests(self, requests: List[ExecutorRequest]): req_ids = [] try: self.enqueue_lock.acquire() - start_time = time.time() for request in requests: - self.start_times[self.next_request_id] = start_time + req_id = self._get_request_id() + + if self.enable_iter_perf_stats: + self.start_times[req_id] = time.time() + + child_req_ids = self._generate_child_request_ids(request) self.request_queue.put( - RequestQueueItem(self.next_request_id, request)) - req_ids.append(self.next_request_id) - self.next_request_id += 1 + RequestQueueItem(req_id, request, child_req_ids, + query=None)) + + req_ids.append(req_id) finally: self.enqueue_lock.release() return req_ids @@ -186,15 +224,18 @@ def enqueue_request(self, try: self.enqueue_lock.acquire() assert self.active, "PyExecutor has already been shutdown." - req_id = self.next_request_id + req_id = self._get_request_id() if self.enable_iter_perf_stats: self.start_times[req_id] = time.time() - if query is not None: - self.request_queue.put(RequestQueueItem(req_id, request, query)) - else: - self.request_queue.put(RequestQueueItem(req_id, request)) - self.next_request_id += 1 + child_req_ids = self._generate_child_request_ids(request) + self.request_queue.put( + RequestQueueItem( + req_id, + request, + child_req_ids=child_req_ids, + query=query, + )) finally: self.enqueue_lock.release() @@ -530,6 +571,10 @@ def _update_new_active_requests_queue_latency( if req_item.id in self.start_times: self.new_active_requests_queue_latency_ms += now - self.start_times.pop( req_item.id) + if req_item.child_req_ids: + for child_id in req_item.child_req_ids: + self.new_active_requests_queue_latency_ms += now - self.start_times.pop( + child_id) @nvtx_range("_merge_requests") def _merge_requests(self, new_requests: list[RequestQueueItem]): @@ -543,12 +588,15 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]): else: raise NotImplementedError(f'unsupport cp type {cp_type}') else: - return [ - executor_request_to_llm_request( - req_item.id, req_item.request, + req_with_children = [] + for req_item in new_requests: + req = executor_request_to_llm_request( + req_item.id, req_item.request, req_item.child_req_ids, self._should_exclude_last_generation_logits()) - for req_item in new_requests - ] + req_with_children.append(req) + if req.child_requests: + req_with_children.extend(req.child_requests) + return req_with_children def _merge_star_attention_requests(self, new_requests: list[RequestQueueItem]): diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 2570a39346e..6bce46adc53 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1,3 +1,4 @@ +from copy import deepcopy from dataclasses import dataclass from typing import List, Optional, Union @@ -277,22 +278,28 @@ def __init__( exclude_last_generation_logits: bool = False, return_perf_metrics: bool = False, stop_words_list: list[list[int]] | None = None, + llm_request: Optional[ + tensorrt_llm.bindings.internal.batch_manager.LlmRequest] = None, is_draft: bool = False, **kwargs): + self.py_logits_post_processors = 
kwargs.pop("py_logits_post_processors", None) # Multimodal data self.py_multimodal_data = kwargs.pop("py_multimodal_data", None) - super().__init__( - *args, - client_id=client_id, - return_log_probs=return_log_probs, - return_context_logits=False, - return_generation_logits=False, - return_perf_metrics=return_perf_metrics, - stop_words_list=torch.tensor(stop_words_list, dtype=torch.int32) - if stop_words_list else None, - **kwargs) + if llm_request is not None: + super().__init__(llm_request) + else: + super().__init__( + *args, + client_id=client_id, + return_log_probs=return_log_probs, + return_context_logits=False, + return_generation_logits=False, + return_perf_metrics=return_perf_metrics, + stop_words_list=torch.tensor(stop_words_list, dtype=torch.int32) + if stop_words_list else None, + **kwargs) self.py_client_id = client_id self.py_request_id = self.request_id self.py_llm_request_type = self.llm_request_type @@ -327,6 +334,7 @@ def __init__( return_log_probs, return_context_logits, return_generation_logits, exclude_last_generation_logits) + self.child_requests = [] def is_generation_only_request(self): return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY @@ -338,7 +346,8 @@ def create_response( result, is_final = super().create_serialized_result( use_fast_logits, mpi_world_rank) return LlmResponse( - request_id=self.py_request_id, + request_id=self.py_request_id + if self.is_child else self.parent_request_id, result=LlmResult(result, self.py_result, is_final), client_id=self.py_client_id) if len(result) > 0 else None @@ -351,6 +360,32 @@ def finish_by(self, reason: FinishReason, beam: int) -> None: self.state = LlmRequestState.GENERATION_COMPLETE self.set_finished_reason(reason, beam) + def create_child_request(self, child_id): + child = super().create_child_request(child_id) + py_request = LlmRequest(llm_request=child) + + # Copy all py_* attributes from parent to child + for attr_name, attr_value in self.__dict__.items(): + if attr_name.startswith('py_'): + attr_value = getattr(self, attr_name) + setattr(py_request, attr_name, deepcopy(attr_value)) + elif attr_name in ['is_attention_dp_dummy', 'is_cuda_graph_dummy']: + setattr(py_request, attr_name, attr_value) + + # Rewrite specific attributes that should use child_request values. 
+ py_request.py_request_id = child.request_id + py_request.py_batch_idx = None + py_request.py_seq_slot = None + + py_request.child_requests = [] + + assert py_request.is_child + assert py_request.request_id == child.request_id + assert py_request.parent_request_id == self.request_id + assert py_request.sampling_config.random_seed != self.sampling_config.random_seed + + self.child_requests.append(py_request) + def convert_wordlist(word_list) -> List[List[int]]: """Converts a wordlist from format: @@ -392,6 +427,7 @@ def convert_wordlist(word_list) -> List[List[int]]: def executor_request_to_llm_request( req_id: int, executor_request: ExecutorRequest, + child_req_ids: List[int], exclude_last_generation_logits: bool, input_token_ids: Optional[List] = None) -> LlmRequest: executor_sampling_config = executor_request.sampling_config @@ -476,6 +512,10 @@ def executor_request_to_llm_request( context_phase_params=executor_request.context_phase_params, py_multimodal_data=getattr(executor_request, "py_multimodal_data", None)) + if child_req_ids: + for child_id in child_req_ids: + llm_request.create_child_request(child_id) + return llm_request diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 144786614c3..cbc13acd522 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -8,7 +8,7 @@ import traceback import weakref from contextlib import contextmanager -from typing import Dict, List, Optional, Union +from typing import List, Optional, Tuple, Union import torch from cuda import cudart @@ -170,7 +170,7 @@ def __init__(self, self.disable_overlap_scheduler = disable_overlap_scheduler # enqueue and _fetch_new_requests used data - self.next_req_id = max_batch_size # The first max_batch_size request IDs are reserved for dummy requests + self.active = True self.max_beam_width = max_beam_width self.max_draft_len = max_draft_len self.print_log = model_engine.pytorch_backend_config.print_iter_log @@ -1421,7 +1421,7 @@ def _handle_canceled_requests(self): self.executor_request_queue.update_waiting_queue() for request in self.active_requests: - req_id = request.py_request_id + req_id = request.py_request_id if not request.is_child else request.parent_request_id if req_id in self.executor_request_queue.get_canceled_req_ids(): # Mark requests as finished, then, we reuse all existing code # to clean up the KV cache resources. 
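Illustrative sketch (not part of the patch): create_child_request above wraps the C++ child in a new Python LlmRequest, copies the parent's py_* attributes, and registers it on parent.child_requests. A condensed version of the call pattern exercised by tests/unittest/_torch/test_best_of_n.py further down; the ids, token values, and client_id here are arbitrary:

    # Sketch only: condensed from test_create_child_request below.
    from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, SamplingConfig

    sampling_config = SamplingConfig()
    sampling_config.num_return_sequences = 3  # one parent plus two children

    parent = LlmRequest(request_id=1,
                        max_new_tokens=8,
                        input_tokens=[1, 2, 3],
                        sampling_config=sampling_config,
                        is_streaming=False,
                        client_id=42)

    # Child ids are normally allocated by the executor request queue.
    for child_id in (2, 3):
        parent.create_child_request(child_id)

    assert len(parent.child_requests) == 2
    assert all(child.parent_request_id == parent.request_id
               for child in parent.child_requests)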
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 144786614c3..cbc13acd522 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -8,7 +8,7 @@
 import traceback
 import weakref
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from cuda import cudart
@@ -170,7 +170,7 @@ def __init__(self,
         self.disable_overlap_scheduler = disable_overlap_scheduler

         # enqueue and _fetch_new_requests used data
-        self.next_req_id = max_batch_size  # The first max_batch_size request IDs are reserved for dummy requests
+        self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
@@ -1421,7 +1421,7 @@ def _handle_canceled_requests(self):
         self.executor_request_queue.update_waiting_queue()

         for request in self.active_requests:
-            req_id = request.py_request_id
+            req_id = request.py_request_id if not request.is_child else request.parent_request_id
             if req_id in self.executor_request_queue.get_canceled_req_ids():
                 # Mark requests as finished, then, we reuse all existing code
                 # to clean up the KV cache resources.
@@ -1435,7 +1435,7 @@ def _handle_canceled_requests(self):
         self.executor_request_queue.clear_canceled_req_ids()

     @nvtx_range("_enqueue_responses")
-    def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
+    def _enqueue_responses(self, responses: List[Tuple[int, LlmResponse]]):
         if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
             return

@@ -1447,18 +1447,18 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
         else:
             responses_list = self.dist.allgather(responses)
             if self.dist.rank == 0 or self.gather_all_responses:
-                gather_responses = {}
+                gather_responses = []
                 if responses_list is not None:
                     for resp in responses_list:
                         if resp is not None:
-                            gather_responses.update(resp)
+                            gather_responses.extend(resp)
                 responses = gather_responses
         logger.debug(
             f'after gather, rank = {self.dist.rank}, responses = {responses}')

         if self.dist.rank == 0 or self.gather_all_responses:
             with self.response_cv:
-                for req_id, resp in responses.items():
+                for req_id, resp in responses:
                     if req_id in self.responses.keys():
                         self.responses[req_id].append(resp)
                     else:
@@ -1467,20 +1467,20 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):

     @nvtx_range("_handle_first_token_response")
     def _handle_first_token_response(self, scheduled_batch):
-        new_responses = {}
+        new_responses = []
         for req in scheduled_batch.generation_requests:
             if req.py_decoding_iter == 1:
                 logger.debug(
                     f'Send first token response for request {req.py_request_id}'
                 )
                 response = req.create_response(False, self.dist.rank)
-                new_responses.update({req.py_request_id: response})
+                new_responses.append((req.py_request_id, response))

         self._enqueue_responses(new_responses)

     @nvtx_range("_handle_responses")
     def _handle_responses(self):
-        new_responses = {}
+        new_responses = []
         requests_to_terminate = []
         new_active_requests = []
         logger.debug(
@@ -1514,8 +1514,8 @@ def _handle_responses(self):
                     request.py_decoding_iter % self.stream_interval == 0:
                 response = request.create_response(False, self.dist.rank)
                 if response:
-                    request_done = response.result.is_final
-                    new_responses.update({req_id: response})
+                    request_done = request.is_finished
+                    new_responses.append((req_id, response))

             if request_done:
                 if request.is_disagg_context_transmission_state:
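Illustrative sketch (not part of the patch): _enqueue_responses now takes a list of (request_id, response) tuples rather than a dict. One plausible reason, not stated in the diff itself, is that sibling sequences can report under the same request id, and a dict keyed by that id would keep only the last response:

    # Sketch: a list of (req_id, response) pairs preserves sibling responses
    # where a dict keyed by req_id would not. The "resp_*" values are placeholders.
    responses_as_dict = {}
    responses_as_list = []

    for req_id, resp in [(7, "resp_parent"), (7, "resp_child_0"), (7, "resp_child_1")]:
        responses_as_dict.update({req_id: resp})  # later entries overwrite earlier ones
        responses_as_list.append((req_id, resp))  # every response is kept

    assert len(responses_as_dict) == 1
    assert len(responses_as_list) == 3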
+ ], + } + + +@pytest.fixture(scope="module") +def llm(): + return LLM(model=os.path.join(llm_models_root(), "llama-models-v2", + "TinyLlama-1.1B-Chat-v1.0"), + kv_cache_config=KvCacheConfig(max_tokens=1000), + max_batch_size=8, + max_seq_len=64, + enable_trtllm_sampler=True, + disable_overlap_scheduler=True) + + +@force_ampere # Save H100 resource +@pytest.mark.parametrize("n", [1, 2, 3]) +def test_create_child_request(n: int): + sampling_config = SamplingConfig() + setattr(sampling_config, 'top_p', [0.9]) + setattr(sampling_config, 'num_return_sequences', n) + + parent = LlmRequest( + request_id=1, + max_new_tokens=10, + input_tokens=[1, 2, 3], + sampling_config=sampling_config, + is_streaming=False, + client_id=50, + return_log_probs=True, + return_context_logits=True, + ) + + for child_id in range( + parent.request_id + 1, + parent.request_id + parent.sampling_config.num_return_sequences): + parent.create_child_request(child_id) + + assert len(parent.child_requests + ) == parent.sampling_config.num_return_sequences - 1 + + for ind, child in enumerate(parent.child_requests): + assert child.request_id == ind + parent.request_id + 1 + assert child.py_request_id == child.request_id + assert child.parent_request_id == parent.request_id + + assert child.py_client_id == 50 + assert child.py_max_new_tokens == 10 + + assert child.py_return_log_probs == parent.py_return_log_probs + assert child.py_return_context_logits == parent.py_return_context_logits + + assert child.py_batch_idx is None + + # Verify parent - child independence + assert child.py_result is not None + assert child.py_result is not parent.py_result + assert child.get_tokens() == parent.get_tokens() + assert child.get_tokens() is not parent.get_tokens() + + assert child.child_requests == [] + + +@force_ampere # Save H100 resource +@pytest.mark.parametrize("n", [2]) +@pytest.mark.parametrize("best_of", [None, 3]) +@pytest.mark.threadleak(enabled=False) +def test_n_outputs(n: int, best_of: int, llm, input_prompts, expected_outputs): + sampling_params = SamplingParams( + n=n, + best_of=best_of, + temperature=0.8, # ensure different outputs + top_p=0.95, # ensure different outputs + use_beam_search=False) + for output_idx, output in enumerate( + llm.generate(input_prompts, sampling_params=sampling_params)): + + assert len(output.outputs) == n + + for idx, sequence in enumerate(output.outputs): + if n == best_of: + assert similar(sequence.text, + expected_outputs[input_prompts[output_idx]][idx]) + + +@pytest.mark.parametrize("n", [3]) +@pytest.mark.threadleak(enabled=False) +def test_async_n_outputs(n: int, llm, input_prompts): + sampling_params = SamplingParams( + n=n, + temperature=0.8, # ensure different outputs + top_p=0.95, # ensure different outputs + use_beam_search=False) + + # Asynchronously submit many requests to exceed max batch size. + futures = [] + for _ in range(5): + for prompt in input_prompts: + future = llm.generate_async(prompt, sampling_params) + futures.append(future) + + # Expect no error raised and each result contains n outputs. 
+ for _, future in enumerate(futures): + request_output = future.result() + assert len(request_output.outputs) == n diff --git a/tests/unittest/_torch/test_executor_request_queue.py b/tests/unittest/_torch/test_executor_request_queue.py index 586a355f159..fdcd304b412 100644 --- a/tests/unittest/_torch/test_executor_request_queue.py +++ b/tests/unittest/_torch/test_executor_request_queue.py @@ -75,7 +75,8 @@ def test_enqueue_requests(executor_queue): """Test enqueuing multiple requests.""" mock_requests = [Mock(), Mock(), Mock()] - with patch('time.time', return_value=1234.5): + with (patch('time.time', return_value=1234.5), + patch.object(executor_queue, '_generate_child_request_ids')): req_ids = executor_queue.enqueue_requests(mock_requests) # type: ignore assert len(req_ids) == 3 @@ -92,7 +93,8 @@ def test_enqueue_request_single(executor_queue): """Test enqueuing a single request.""" mock_request = Mock() - with patch('time.time', return_value=1234.5): + with (patch('time.time', return_value=1234.5), + patch.object(executor_queue, '_generate_child_request_ids')): req_id = executor_queue.enqueue_request(mock_request) assert req_id == 8 @@ -104,8 +106,8 @@ def test_enqueue_request_with_query(executor_queue): """Test enqueuing a request with query data.""" mock_request = Mock() query_data = [1, 2, 3, 4] - - req_id = executor_queue.enqueue_request(mock_request, query=query_data) + with patch.object(executor_queue, '_generate_child_request_ids'): + req_id = executor_queue.enqueue_request(mock_request, query=query_data) assert req_id == 8 @@ -115,6 +117,31 @@ def test_enqueue_request_with_query(executor_queue): assert item.request == mock_request +@pytest.mark.parametrize("n_children", [0, 1, 2]) +def test_enqueue_request_with_child_ids(executor_queue, n_children): + """Test enqueuing a request with query data.""" + mock_request = Mock() + query_data = [1, 2, 3, 4] + with patch.object(executor_queue, + '_get_num_child_requests') as mock_children: + mock_children.return_value = n_children + req_id = executor_queue.enqueue_request(mock_request, query=query_data) + + assert req_id == 8 + + # Verify the item was enqueued with child ids + item = executor_queue.request_queue.get_nowait() + assert item.id == req_id + assert item.request == mock_request + if n_children == 0: + assert item.child_req_ids is None + else: + assert item.child_req_ids is not None + assert len(item.child_req_ids) == n_children + assert item.child_req_ids == list( + range(1 + req_id, 1 + req_id + n_children)) + + def test_enqueue_cancel_request(executor_queue): """Test enqueuing a cancel request.""" req_id = 42 @@ -253,11 +280,10 @@ def test_validate_and_filter_requests(executor_queue): ) def test_merge_requests_default(mock_convert, executor_queue): """Test merging requests with default configuration.""" - mock_llm_request = Mock() + mock_llm_request = Mock(child_requests=[]) mock_convert.return_value = mock_llm_request requests = [RequestQueueItem(1, Mock()), RequestQueueItem(2, Mock())] - result = executor_queue._merge_requests(requests) assert len(result) == 2