Skip to content

Commit d101a6c

Browse files
authored
[https://nvbugs/5410279][test] resubmit timeout refactor (#6337)
Signed-off-by: Ivy Zhang <[email protected]>
1 parent 7cbe30e commit d101a6c

File tree

11 files changed: +611 additions, −231 deletions

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -701,26 +701,59 @@ def run(self,
701701
extra_build_args: Optional[list] = None,
702702
extra_summarize_args: Optional[list] = None,
703703
extra_eval_long_context_args: Optional[list] = None,
704-
env: Optional[Dict[str, str]] = None):
705-
self.install_requirements()
706-
self.initialize_case(
707-
tasks=tasks,
708-
dtype=dtype,
709-
quant_algo=quant_algo,
710-
kv_cache_quant_algo=kv_cache_quant_algo,
711-
spec_dec_algo=spec_dec_algo,
712-
extra_acc_spec=extra_acc_spec,
713-
tp_size=tp_size,
714-
pp_size=pp_size,
715-
cp_size=cp_size,
716-
extra_convert_args=extra_convert_args,
717-
extra_build_args=extra_build_args,
718-
extra_summarize_args=extra_summarize_args,
719-
extra_eval_long_context_args=extra_eval_long_context_args,
720-
env=env)
721-
self.convert()
722-
self.build()
723-
self.evaluate()
704+
env: Optional[Dict[str, str]] = None,
705+
timeout_manager=None):
706+
"""
707+
Run all accuracy test phases with timeout management.
708+
If timeout_manager is provided, each phase will be wrapped to track and deduct remaining timeout.
709+
"""
710+
# Use timeout_manager to manage timeout for each phase
711+
if timeout_manager is not None:
712+
with timeout_manager.timed_operation("install_requirements"):
713+
self.install_requirements()
714+
with timeout_manager.timed_operation("initialize_case"):
715+
self.initialize_case(
716+
tasks=tasks,
717+
dtype=dtype,
718+
quant_algo=quant_algo,
719+
kv_cache_quant_algo=kv_cache_quant_algo,
720+
spec_dec_algo=spec_dec_algo,
721+
extra_acc_spec=extra_acc_spec,
722+
tp_size=tp_size,
723+
pp_size=pp_size,
724+
cp_size=cp_size,
725+
extra_convert_args=extra_convert_args,
726+
extra_build_args=extra_build_args,
727+
extra_summarize_args=extra_summarize_args,
728+
extra_eval_long_context_args=extra_eval_long_context_args,
729+
env=env)
730+
with timeout_manager.timed_operation("convert"):
731+
self.convert()
732+
with timeout_manager.timed_operation("build"):
733+
self.build()
734+
with timeout_manager.timed_operation("evaluate"):
735+
self.evaluate()
736+
else:
737+
# fallback: no timeout management
738+
self.install_requirements()
739+
self.initialize_case(
740+
tasks=tasks,
741+
dtype=dtype,
742+
quant_algo=quant_algo,
743+
kv_cache_quant_algo=kv_cache_quant_algo,
744+
spec_dec_algo=spec_dec_algo,
745+
extra_acc_spec=extra_acc_spec,
746+
tp_size=tp_size,
747+
pp_size=pp_size,
748+
cp_size=cp_size,
749+
extra_convert_args=extra_convert_args,
750+
extra_build_args=extra_build_args,
751+
extra_summarize_args=extra_summarize_args,
752+
extra_eval_long_context_args=extra_eval_long_context_args,
753+
env=env)
754+
self.convert()
755+
self.build()
756+
self.evaluate()
724757

725758

726759
class LlmapiAccuracyTestHarness:

tests/integration/defs/accuracy/test_cli_flow.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,14 +1167,15 @@ class TestMixtral8x22B(CliFlowAccuracyTestHarness):
11671167
@skip_pre_ada
11681168
@pytest.mark.skip_less_device(4)
11691169
@pytest.mark.skip_less_device_memory(80000)
1170-
def test_fp8_tp2pp2(self):
1170+
def test_fp8_tp2pp2(self, timeout_manager):
11711171
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
11721172
MMLU(self.MODEL_NAME)],
11731173
quant_algo=QuantAlgo.FP8,
11741174
tp_size=2,
11751175
pp_size=2,
11761176
extra_convert_args=["--calib_size=32"],
1177-
extra_build_args=["--gemm_plugin=auto"])
1177+
extra_build_args=["--gemm_plugin=auto"],
1178+
timeout_manager=timeout_manager)
11781179

11791180
@skip_post_blackwell
11801181
@pytest.mark.skip_less_device(8)
@@ -1184,7 +1185,8 @@ def test_fp8_tp2pp2(self):
11841185
ids=['expert_parallel', 'mixed_parallel', 'tensor_parallel'])
11851186
@pytest.mark.parametrize("moe_renorm_mode", [0, 1],
11861187
ids=['no_renormalize', 'renormalize'])
1187-
def test_int8_plugin_tp8(self, moe_tp_size, moe_renorm_mode):
1188+
def test_int8_plugin_tp8(self, moe_tp_size, moe_renorm_mode,
1189+
timeout_manager):
11881190
self.run(quant_algo=QuantAlgo.W8A16,
11891191
tp_size=8,
11901192
extra_convert_args=[
@@ -1195,7 +1197,8 @@ def test_int8_plugin_tp8(self, moe_tp_size, moe_renorm_mode):
11951197
extra_build_args=[
11961198
"--max_beam_width=4", "--gemm_plugin=auto",
11971199
"--moe_plugin=auto", f"--max_seq_len={8192}"
1198-
])
1200+
],
1201+
timeout_manager=timeout_manager)
11991202

12001203

12011204
class TestGemma2B(CliFlowAccuracyTestHarness):

tests/integration/defs/common.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _war_check_output(*args, **kwargs):
4444
return venv.run_cmd(cmd, caller=_war_check_output, env=env, **kwargs)
4545

4646

47-
def venv_mpi_check_call(venv, mpi_cmd, python_cmd):
47+
def venv_mpi_check_call(venv, mpi_cmd, python_cmd, **kwargs):
4848
"""
4949
This function WAR check_call() to run python_cmd with mpi.
5050
If mpi_cmd = ["mpirun", "-n", "2"] and python_cmd = ["run.py"], the command will be:
@@ -61,10 +61,10 @@ def _war_check_call(*args, **kwargs):
6161
kwargs["cwd"] = venv.get_working_directory()
6262
return check_call(merged_cmd, **kwargs)
6363

64-
venv.run_cmd(python_cmd, caller=_war_check_call)
64+
venv.run_cmd(python_cmd, caller=_war_check_call, **kwargs)
6565

6666

67-
def venv_mpi_check_output(venv, mpi_cmd, python_cmd, env=None):
67+
def venv_mpi_check_output(venv, mpi_cmd, python_cmd, env=None, **kwargs):
6868
"""
6969
This function WAR check_output() to run python_cmd with mpi.
7070
If mpi_cmd = ["mpirun", "-n", "2"] and python_cmd = ["run.py"], the command will be:
@@ -81,7 +81,7 @@ def _war_check_output(*args, **kwargs):
8181
kwargs["cwd"] = venv.get_working_directory()
8282
return check_output(merged_cmd, **kwargs)
8383

84-
return venv.run_cmd(python_cmd, caller=_war_check_output, env=env)
84+
return venv.run_cmd(python_cmd, caller=_war_check_output, env=env, **kwargs)
8585

8686

8787
def parse_mpi_cmd(cmd):
@@ -506,6 +506,7 @@ def convert_weights(llm_venv,
506506
convert_cmd.append(f"--quant_ckpt_path={quant_ckpt_path}")
507507
if per_group:
508508
convert_cmd.append("--per_group")
509+
timeout = kwargs.pop('timeout', None)
509510

510511
for key, value in kwargs.items():
511512
if isinstance(value, bool):
@@ -515,7 +516,7 @@ def convert_weights(llm_venv,
515516
convert_cmd.extend([f"--{key}={value}"])
516517

517518
if llm_venv:
518-
venv_check_call(llm_venv, convert_cmd)
519+
venv_check_call(llm_venv, convert_cmd, timeout=timeout)
519520
return model_dir
520521
else:
521522
return convert_cmd, model_dir
@@ -607,6 +608,7 @@ def quantize_data(llm_venv,
607608

608609
if kv_cache_dtype:
609610
quantize_cmd.append(f"--kv_cache_dtype={kv_cache_dtype}")
611+
timeout = kwargs.pop('timeout', None)
610612

611613
for key, value in kwargs.items():
612614
if isinstance(value, bool):
@@ -617,7 +619,7 @@ def quantize_data(llm_venv,
617619

618620
if llm_venv:
619621
if not exists(output_dir):
620-
venv_check_call(llm_venv, quantize_cmd)
622+
venv_check_call(llm_venv, quantize_cmd, timeout=timeout)
621623
return output_dir
622624
else:
623625
return quantize_cmd, output_dir

tests/integration/defs/conftest.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2351,3 +2351,38 @@ def tritonserver_test_root(llm_root):
23512351
"tests/integration/defs/triton_server")
23522352

23532353
return tritonserver_root
2354+
2355+
2356+
@pytest.fixture
2357+
def timeout_from_marker(request):
2358+
"""Get timeout value from pytest timeout marker."""
2359+
timeout_marker = request.node.get_closest_marker('timeout')
2360+
if timeout_marker:
2361+
return timeout_marker.args[0] if timeout_marker.args else None
2362+
return None
2363+
2364+
2365+
@pytest.fixture
2366+
def timeout_from_command_line(request):
2367+
"""Get timeout value from command line --timeout parameter."""
2368+
# Get timeout from command line argument
2369+
timeout_arg = request.config.getoption("--timeout", default=None)
2370+
if timeout_arg is not None:
2371+
return float(timeout_arg)
2372+
return None
2373+
2374+
2375+
@pytest.fixture
2376+
def timeout_manager(timeout_from_command_line, timeout_from_marker):
2377+
"""Create a TimeoutManager instance with priority: marker > cmdline > config."""
2378+
from defs.utils.timeout_manager import TimeoutManager
2379+
2380+
# Priority: marker > command line
2381+
timeout_value = None
2382+
2383+
if timeout_from_marker is not None:
2384+
timeout_value = timeout_from_marker
2385+
elif timeout_from_command_line is not None:
2386+
timeout_value = timeout_from_command_line
2387+
2388+
return TimeoutManager(timeout_value)

tests/integration/defs/examples/test_commandr.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,27 @@ def test_llm_commandr_plus_4gpus_summary(commandr_example_root,
9494
llm_commandr_plus_model_root,
9595
llm_datasets_root, llm_rouge_root,
9696
llm_venv, cmodel_dir, engine_dir,
97-
use_weight_only):
97+
use_weight_only, timeout_manager):
9898
"Build & run Command-R+ with smoothquant on 4 gpus."
9999
dtype = 'float16'
100100
tp_size = 4
101101
model_name = os.path.basename(llm_commandr_plus_model_root)
102-
print("Converting checkpoint...")
103-
ckpt_dir = convert_weights(llm_venv=llm_venv,
104-
example_root=commandr_example_root,
105-
cmodel_dir=cmodel_dir,
106-
model=model_name,
107-
model_path=llm_commandr_plus_model_root,
108-
data_type=dtype,
109-
tp_size=tp_size,
110-
gpus=tp_size,
111-
use_weight_only=use_weight_only)
112102

103+
# Convert checkpoint with timeout management
104+
print("Converting checkpoint...")
105+
with timeout_manager.timed_operation("convert"):
106+
ckpt_dir = convert_weights(llm_venv=llm_venv,
107+
example_root=commandr_example_root,
108+
cmodel_dir=cmodel_dir,
109+
model=model_name,
110+
model_path=llm_commandr_plus_model_root,
111+
data_type=dtype,
112+
tp_size=tp_size,
113+
gpus=tp_size,
114+
use_weight_only=use_weight_only,
115+
timeout=timeout_manager.remaining_timeout)
116+
117+
# Build engines with timeout management
113118
print("Building engines...")
114119
build_cmd = [
115120
"trtllm-build",
@@ -130,12 +135,23 @@ def test_llm_commandr_plus_4gpus_summary(commandr_example_root,
130135
f"--engine_dir={engine_dir}",
131136
]
132137

133-
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
134-
135-
venv_mpi_check_call(
136-
llm_venv,
137-
["mpirun", "-n", str(tp_size), "--allow-run-as-root"], run_cmd)
138-
138+
with timeout_manager.timed_operation("build"):
139+
check_call(" ".join(build_cmd),
140+
shell=True,
141+
env=llm_venv._new_env,
142+
timeout=timeout_manager.remaining_timeout)
143+
144+
# Run engines with timeout management
145+
print("Running engines...")
146+
with timeout_manager.timed_operation("run"):
147+
venv_mpi_check_call(
148+
llm_venv, ["mpirun", "-n",
149+
str(tp_size), "--allow-run-as-root"],
150+
run_cmd,
151+
timeout=timeout_manager.remaining_timeout)
152+
153+
# Run summary with timeout management
154+
print("Running summary...")
139155
summary_cmd = generate_summary_cmd(
140156
commandr_example_root,
141157
hf_model_dir=llm_commandr_plus_model_root,
@@ -144,6 +160,9 @@ def test_llm_commandr_plus_4gpus_summary(commandr_example_root,
144160
dataset_dir=llm_datasets_root,
145161
rouge_dir=llm_rouge_root)
146162

147-
venv_mpi_check_call(
148-
llm_venv,
149-
["mpirun", "-n", str(tp_size), "--allow-run-as-root"], summary_cmd)
163+
with timeout_manager.timed_operation("summary"):
164+
venv_mpi_check_call(
165+
llm_venv, ["mpirun", "-n",
166+
str(tp_size), "--allow-run-as-root"],
167+
summary_cmd,
168+
timeout=timeout_manager.remaining_timeout)

0 commit comments

Comments (0)