 # limitations under the License.
 
 import os
+import re
 import subprocess
+import tempfile
 
 import pytest
-from defs.conftest import skip_arm, skip_no_hopper
-from defs.trt_test_alternative import check_call, popen
+import yaml
+from defs.conftest import llm_models_root, skip_arm, skip_no_hopper
+from defs.trt_test_alternative import check_call, check_output, popen
 
 from tensorrt_llm.logger import logger
 
@@ -1051,3 +1054,232 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_two_mtp(
                            "deepseek_v3_lite_fp8_tp1_two_mtp",
                            env=llm_venv._new_env,
                            cwd=llm_venv.get_working_directory())
+
+
+@pytest.fixture(scope="module")
+def benchmark_root():
+    llm_root = os.getenv("LLM_ROOT")
+    return os.path.join(llm_root, "tensorrt_llm", "serve", "scripts")
+
+
+@pytest.fixture(scope="module")
+def shared_gpt_path():
+    DEFAULT_LLM_MODEL_ROOT = os.path.join("/scratch.trt_llm_data", "llm-models")
+    LLM_MODELS_ROOT = os.environ.get("LLM_MODELS_ROOT", DEFAULT_LLM_MODEL_ROOT)
+    return os.path.join(LLM_MODELS_ROOT, "datasets",
+                        "ShareGPT_V3_unfiltered_cleaned_split.json")
+
+
+@pytest.fixture(scope="function")
+def benchmark_model_root(request):
+    models_root = llm_models_root()
+    if request.param == "DeepSeek-V3-Lite-fp8":
+        model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "fp8")
+    elif request.param == "DeepSeek-V3-Lite-bf16":
+        model_path = os.path.join(models_root, "DeepSeek-V3-Lite", "bf16")
+    elif request.param == "llama-v3-8b-hf":
+        model_path = os.path.join(models_root, "llama-models-v3", "8B")
+    elif request.param == "llama-3.1-8b-instruct-hf-fp8":
+        model_path = os.path.join(models_root, "llama-3.1-model",
+                                  "Llama-3.1-8B-Instruct-FP8")
+    else:
+        raise ValueError(f"Unknown benchmark model: {request.param}")
+    return model_path
+
+
+def run_disaggregated_benchmark(example_dir,
+                                config_file,
+                                benchmark_root,
+                                benchmark_model_root,
+                                shared_gpt_path,
+                                env=None,
+                                cwd=None):
+    """Run a disaggregated benchmark with the given config; return median E2EL and TTFT in ms."""
+    run_env = (env or os.environ).copy()  # tolerate the env=None default
+    run_env["UCX_TLS"] = "^ib"  # exclude InfiniBand transports in UCX
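+    # Two MPI ranks: one context worker and one generation worker.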
+    num_rank = 2
+    workers_cmd = [
+        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
+        str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        config_file
+    ]
+
+    server_start_timeout = 900
+    server_cmd = [
+        'trtllm-serve', 'disaggregated', '--server_start_timeout',
+        str(server_start_timeout), '-c', config_file
+    ]
+    try:
+        with (  # Start workers
+                open('output_workers.log', 'w') as output_workers,
+                popen(workers_cmd,
+                      stdout=output_workers,
+                      stderr=subprocess.STDOUT,
+                      env=run_env,
+                      cwd=cwd) as workers_proc,
+                # Start server
+                open('output_disagg.log', 'w') as output_disagg,
+                popen(server_cmd,
+                      stdout=output_disagg,
+                      stderr=subprocess.STDOUT,
+                      env=run_env,
+                      cwd=cwd) as server_proc):
+            # Ensure the server has started
+            client_dir = f"{example_dir}/clients"
+            client_cmd = [
+                'python3', f'{client_dir}/disagg_client.py', '-c',
+                f'{example_dir}/disagg_config.yaml', '-p',
+                f'{client_dir}/prompts.json', '--ignore-eos',
+                '--server-start-timeout',
+                str(server_start_timeout)
+            ]
+            # Warm up
+            check_call(client_cmd,
+                       env=env,
+                       poll_procs=[workers_proc, server_proc])
+            # Start Benchmark
+            benchmark_script = os.path.join(benchmark_root,
+                                            "benchmark_serving.py")
+            benchmark_cmd = [
+                'python3',
+                benchmark_script,
+                '--model',
+                benchmark_model_root,
+                '--tokenizer',
+                benchmark_model_root,
+                '--dataset-name',
+                'random',
+                '--dataset-path',
+                shared_gpt_path,
+                '--random-input-len',
+                '256',
+                '--random-output-len',
+                '64',
+                '--random-prefix-len',
+                '0',
+                '--num-prompts',
+                '320',
+                '--max-concurrency',
+                '32',
+                '--host',
+                'localhost',
+                '--port',
+                '8000',
+                '--ignore-eos',
+                '--no-test-input',
+                '--percentile-metrics',
+                'e2el,ttft',
+            ]
+            # Warm-up run, then a measured run whose output is parsed below
+            check_call(benchmark_cmd, env=env)
+            output = check_output(benchmark_cmd, env=env)
+            e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
+            ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"
+            e2el_match = re.search(e2el_pattern, output)
+            ttft_match = re.search(ttft_pattern, output)
+            if e2el_match and ttft_match:
+                median_e2el = float(e2el_match.group(1))
+                median_ttft = float(ttft_match.group(1))
+                return median_e2el, median_ttft
+            else:
+                raise ValueError("No benchmark result found")
+
+    except Exception:
+        # Print outputs on error
+        logger.error("-------- Workers output --------")
+        with open('output_workers.log', 'r') as f:
+            logger.error(f.read())
+
+        logger.error("-------- Disagg server output --------")
+        with open('output_disagg.log', 'r') as f:
+            logger.error(f.read())
+        raise
+    finally:
+        server_proc.terminate()
+        workers_proc.terminate()
+        server_proc.wait()
+        workers_proc.wait()
+
+
+def get_config_for_benchmark(model_root, backend):
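+    """Build a minimal disaggregated config: one context and one generation
+    server, both using the given cache_transceiver_config backend."""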
+    serve_config = {
+        "model": model_root,
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 384,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "disable_overlap_scheduler": True,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "max_batch_size": 2,
+            "max_num_tokens": 384,
+            "max_seq_len": 384,
+            "cache_transceiver_config": {
+                "backend": backend,
+                "max_tokens_in_buffer": 512,
+            },
+            "urls": ["localhost:8002"]
+        }
+    }
+    return serve_config
+
+
+@pytest.mark.parametrize("benchmark_model_root", [
+    'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf',
+    'llama-3.1-8b-instruct-hf-fp8'
+],
+                         indirect=True)
+def test_disaggregated_benchmark_on_diff_backends(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        benchmark_model_root, benchmark_root, shared_gpt_path):
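+    """Compare NIXL and UCX cache transceiver backends on the same benchmark."""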
+    nixl_config = get_config_for_benchmark(benchmark_model_root, "nixl")
+    ucx_config = get_config_for_benchmark(benchmark_model_root, "ucx")
+    temp_dir = tempfile.TemporaryDirectory()
+    nixl_config_path = os.path.join(temp_dir.name, "nixl_config.yaml")
+    ucx_config_path = os.path.join(temp_dir.name, "ucx_config.yaml")
+    with open(nixl_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(nixl_config, f)
+    with open(ucx_config_path, 'w', encoding='utf-8') as f:
+        yaml.dump(ucx_config, f)
+
+    env = llm_venv._new_env.copy()
+    nixl_e2el, nixl_ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        nixl_config_path,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory())
+    ucx_e2el, ucx_ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        ucx_config_path,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory())
+    print(f"Nixl E2EL: {nixl_e2el} ms, UCX E2EL: {ucx_e2el} ms")
+    print(f"Nixl TTFT: {nixl_ttft} ms, UCX TTFT: {ucx_ttft} ms")
+
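+    # NIXL should be within 5% of UCX on both median E2EL and median TTFT.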
+    assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el
+    assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft