Skip to content

Commit

Permalink
Move SGLang related tests (#601)
Browse files Browse the repository at this point in the history
Split from this PR: #590

We have too many tests running on `mi300x-3` and need to move the SGLang
related ones to `mi300x-4`.

This PR moves the workflows for `sglang_integration_tests` and
`sglang_benchmark_tests` to mi300x-4, along with removing the assumption
of static MODEL_PATH and TOKENIZER_PATH, downloading them on demand
instead.
  • Loading branch information
stbaione authored Nov 25, 2024
1 parent eacbd9b commit 9d9f0d3
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 31 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-sglang-benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: llama-mi300x-3
runs-on: mi300x-4
defaults:
run:
shell: bash
Expand Down Expand Up @@ -78,7 +78,7 @@ jobs:
run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

- name: Launch Shortfin Server
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-sglang-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: llama-mi300x-3
runs-on: mi300x-4
defaults:
run:
shell: bash
Expand Down
5 changes: 5 additions & 0 deletions app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,32 @@
import pytest
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from integration_tests.llm.utils import compile_model, export_paged_llm_v1
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
)
from integration_tests.llm.utils import (
compile_model,
export_paged_llm_v1,
download_with_hf_datasets,
)


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")

model_path = request.param["model_path"]
model_name = request.param["model_name"]
model_param_file_name = request.param["model_param_file_name"]
settings = request.param["settings"]
batch_sizes = request.param["batch_sizes"]

mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"

model_path = tmp_dir / model_param_file_name
download_with_hf_datasets(tmp_dir, model_name)

export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

config = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import multiprocessing
import os
Expand All @@ -16,14 +15,14 @@
pytest.importorskip("sglang")
from sglang import bench_serving

from utils import SGLangBenchmarkArgs
from .utils import SGLangBenchmarkArgs, log_jsonl_result

from integration_tests.llm.utils import (
find_available_port,
start_llm_server,
)

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)

device_settings = {
"device_flags": [
Expand All @@ -33,46 +32,40 @@
"device": "hip",
}

# TODO: Download on demand instead of assuming files exist at this path
MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
TOKENIZER_DIR = Path("/data/llama3.1/8b/")


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")


@pytest.mark.parametrize(
"request_rate",
[1, 2, 4, 8, 16, 32],
"request_rate,model_param_file_name",
[
(req_rate, "meta-llama-3.1-8b-instruct.f16.gguf")
for req_rate in [1, 2, 4, 8, 16, 32]
],
)
@pytest.mark.parametrize(
"pre_process_model",
[
(
{
"model_path": MODEL_PATH,
"model_name": "llama3_8B_fp16",
"model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
"settings": device_settings,
"batch_sizes": [1, 4],
}
)
],
indirect=True,
)
def test_sglang_benchmark_server(request_rate, pre_process_model):
def test_sglang_benchmark_server(
request_rate, model_param_file_name, pre_process_model
):
# TODO: Remove when multi-device is fixed
os.environ["ROCR_VISIBLE_DEVICES"] = "1"

tmp_dir = pre_process_model

config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"
tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
tokenizer_path = tmp_dir / "tokenizer.json"
model_path = tmp_dir / model_param_file_name

# Start shortfin llm server
port = find_available_port()
Expand All @@ -81,7 +74,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
tokenizer_path,
config_path,
vmfb_path,
MODEL_PATH,
model_path,
device_settings,
timeout=30,
)
Expand All @@ -91,7 +84,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
backend="shortfin",
num_prompt=10,
base_url=f"http://localhost:{port}",
tokenizer=TOKENIZER_DIR,
tokenizer=tmp_dir,
request_rate=request_rate,
)
output_file = (
Expand All @@ -116,7 +109,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
logger.info("======== RESULTS ========")
log_jsonl_result(benchmark_args.output_file)
except Exception as e:
logger.info(e)
logger.error(e)

server_process.terminate()
server_process.wait()
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

from argparse import Namespace
from dataclasses import dataclass
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class SGLangBenchmarkArgs:
Expand Down Expand Up @@ -54,3 +58,12 @@ def __repr__(self):
f"Tokenizer: {self.tokenizer}\n"
f"Request Rate: {self.request_rate}"
)


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")
2 changes: 1 addition & 1 deletion app_tests/integration_tests/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import requests
from transformers import AutoTokenizer

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)


class AccuracyValidationException(RuntimeError):
Expand Down

0 comments on commit 9d9f0d3

Please sign in to comment.