Move SGLang related tests #601

Merged · 4 commits · Nov 25, 2024
4 changes: 2 additions & 2 deletions .github/workflows/ci-sglang-benchmark.yml
@@ -28,7 +28,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: llama-mi300x-3
runs-on: mi300x-4
defaults:
run:
shell: bash
@@ -78,7 +78,7 @@ jobs:
run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

- name: Launch Shortfin Server
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
2 changes: 1 addition & 1 deletion .github/workflows/ci-sglang-integration-tests.yml
@@ -29,7 +29,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: llama-mi300x-3
runs-on: mi300x-4
defaults:
run:
shell: bash
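Both workflow edits make the same pair of changes: the jobs move from the llama-mi300x-3 runner to mi300x-4, and the benchmark workflow now points pytest at the test's new location inside the sglang_benchmarks package. The sketch below is a hypothetical way to reproduce that benchmark step locally; it assumes a repository checkout with shortfin, the nod-ai sglang fork, and an MI300X/ROCm environment already set up, just as the workflow prepares them.

```python
# Hypothetical local reproduction of the "Launch Shortfin Server" CI step.
# The pytest arguments mirror the updated workflow; the surrounding
# environment (ROCm, shortfin, sglang fork) is assumed to be in place.
import subprocess

cmd = [
    "pytest",
    "-v",
    "app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py",
    "--log-cli-level=INFO",
    "--html=out/llm/sglang/index.html",
]
subprocess.run(cmd, check=True)
```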
5 changes: 5 additions & 0 deletions app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
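The new __init__.py above (just the license header) makes sglang_benchmarks an importable package, which is what allows the relative `from .utils import ...` used by the moved test module further down. A quick smoke test of the new layout might look like the following; it assumes the repository root is on sys.path and that the intermediate directories resolve as packages (or are handled by pytest's rootdir-based importing).

```python
# Hypothetical smoke test for the new package layout; assumes the repository
# root is on sys.path and the intermediate directories are importable.
import importlib

utils = importlib.import_module(
    "app_tests.benchmark_tests.llm.sglang_benchmarks.utils"
)
print(utils.SGLangBenchmarkArgs, utils.log_jsonl_result)
```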
@@ -9,22 +9,32 @@
import pytest
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from integration_tests.llm.utils import compile_model, export_paged_llm_v1
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
)
from integration_tests.llm.utils import (
compile_model,
export_paged_llm_v1,
download_with_hf_datasets,
)


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")

model_path = request.param["model_path"]
model_name = request.param["model_name"]
model_param_file_name = request.param["model_param_file_name"]
settings = request.param["settings"]
batch_sizes = request.param["batch_sizes"]

mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"

model_path = tmp_dir / model_param_file_name
download_with_hf_datasets(tmp_dir, model_name)

export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

config = {
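The reworked pre_process_model fixture above no longer assumes a GGUF sitting under /data: it takes model_name and model_param_file_name from the test parametrization, downloads the weights into a pytest temp directory with download_with_hf_datasets, and exports and compiles from there. Below is a minimal sketch of driving the fixture through indirect parametrization, using the same keys the fixture reads; the empty device_flags list is a placeholder rather than the real settings.

```python
import pytest

# Placeholder settings; the real flags live in the test module's
# device_settings dict.
device_settings = {"device_flags": [], "device": "hip"}


@pytest.mark.parametrize(
    "pre_process_model",
    [
        {
            "model_name": "llama3_8B_fp16",
            "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
            "settings": device_settings,
            "batch_sizes": [1, 4],
        }
    ],
    indirect=True,
)
def test_exported_artifacts_exist(pre_process_model):
    # The fixture hands back the temp directory it populated.
    tmp_dir = pre_process_model
    assert (tmp_dir / "model.mlir").exists()
```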
@@ -4,7 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import multiprocessing
import os
@@ -16,14 +15,14 @@
pytest.importorskip("sglang")
from sglang import bench_serving

from utils import SGLangBenchmarkArgs
from .utils import SGLangBenchmarkArgs, log_jsonl_result

from integration_tests.llm.utils import (
find_available_port,
start_llm_server,
)

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)

device_settings = {
"device_flags": [
@@ -33,46 +32,40 @@
"device": "hip",
}

# TODO: Download on demand instead of assuming files exist at this path
MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
TOKENIZER_DIR = Path("/data/llama3.1/8b/")


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")


@pytest.mark.parametrize(
"request_rate",
[1, 2, 4, 8, 16, 32],
"request_rate,model_param_file_name",
[
(req_rate, "meta-llama-3.1-8b-instruct.f16.gguf")
for req_rate in [1, 2, 4, 8, 16, 32]
],
)
@pytest.mark.parametrize(
"pre_process_model",
[
(
{
"model_path": MODEL_PATH,
"model_name": "llama3_8B_fp16",
"model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
"settings": device_settings,
"batch_sizes": [1, 4],
}
)
],
indirect=True,
)
def test_sglang_benchmark_server(request_rate, pre_process_model):
def test_sglang_benchmark_server(
request_rate, model_param_file_name, pre_process_model
):
# TODO: Remove when multi-device is fixed
os.environ["ROCR_VISIBLE_DEVICES"] = "1"

tmp_dir = pre_process_model

config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"
tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
tokenizer_path = tmp_dir / "tokenizer.json"
model_path = tmp_dir / model_param_file_name

# Start shortfin llm server
port = find_available_port()
@@ -81,7 +74,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
tokenizer_path,
config_path,
vmfb_path,
MODEL_PATH,
model_path,
device_settings,
timeout=30,
)
@@ -91,7 +84,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
backend="shortfin",
num_prompt=10,
base_url=f"http://localhost:{port}",
tokenizer=TOKENIZER_DIR,
tokenizer=tmp_dir,
request_rate=request_rate,
)
output_file = (
@@ -116,7 +109,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
logger.info("======== RESULTS ========")
log_jsonl_result(benchmark_args.output_file)
except Exception as e:
logger.info(e)
logger.error(e)

server_process.terminate()
server_process.wait()
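A small but worthwhile fix in this module (mirrored in app_tests/integration_tests/llm/utils.py further down) is replacing logging.getLogger("__name__") with logging.getLogger(__name__). The quoted form creates a logger literally named "__name__", shared by every module that repeats the typo; the module attribute gives the conventional per-module logger hierarchy. A quick illustration:

```python
import logging

# The quoted string creates a logger literally called "__name__", so every
# module with the same typo logs under one indistinguishable name.
typo_logger = logging.getLogger("__name__")
print(typo_logger.name)  # -> __name__

# The module attribute yields the usual dotted per-module name, which keeps
# log records filterable and attributable by module.
module_logger = logging.getLogger(__name__)
print(module_logger.name)  # -> "__main__" as a script, else the module's dotted path
```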
@@ -6,8 +6,12 @@

from argparse import Namespace
from dataclasses import dataclass
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class SGLangBenchmarkArgs:
@@ -54,3 +58,12 @@ def __repr__(self):
f"Tokenizer: {self.tokenizer}\n"
f"Request Rate: {self.request_rate}"
)


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")
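log_jsonl_result now lives next to SGLangBenchmarkArgs in the package's utils module instead of being private to the benchmark test, with json/logging imports added here to match. A hedged usage sketch follows; the import path and the result keys are illustrative stand-ins for whatever bench_serving actually writes out.

```python
import json
import logging
import tempfile
from pathlib import Path

# Assumed import path for the relocated helper; adjust to however the
# sglang_benchmarks package is reachable on sys.path in your environment.
from app_tests.benchmark_tests.llm.sglang_benchmarks.utils import log_jsonl_result

logging.basicConfig(level=logging.INFO)

# Illustrative keys only; the real file is the single-line JSONL that
# bench_serving emits.
results_file = Path(tempfile.mkdtemp()) / "benchmark_results.jsonl"
results_file.write_text(json.dumps({"request_rate": 1, "median_ttft_ms": 123.4}) + "\n")

log_jsonl_result(results_file)  # logs REQUEST_RATE: 1, MEDIAN_TTFT_MS: 123.4
```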
2 changes: 1 addition & 1 deletion app_tests/integration_tests/llm/utils.py
@@ -15,7 +15,7 @@
import requests
from transformers import AutoTokenizer

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)


class AccuracyValidationException(RuntimeError):