From 214fd64ab86b8941a61909c5b7a517f4e8335190 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 16 Oct 2024 09:11:10 -0700 Subject: [PATCH] integrate shortfin_apps llm test into pytest suite --- shortfin/python/shortfin_apps/llm/run_test.sh | 75 ----------- .../shortfin_apps/llm/test_llm_server.py | 117 ++++++++++++++++++ 2 files changed, 117 insertions(+), 75 deletions(-) delete mode 100644 shortfin/python/shortfin_apps/llm/run_test.sh create mode 100644 shortfin/python/shortfin_apps/llm/test_llm_server.py diff --git a/shortfin/python/shortfin_apps/llm/run_test.sh b/shortfin/python/shortfin_apps/llm/run_test.sh deleted file mode 100644 index 18016dcec..000000000 --- a/shortfin/python/shortfin_apps/llm/run_test.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -set -xeuo pipefail - -mkdir -p /tmp/sharktank/llama - -huggingface-cli download --local-dir /tmp/sharktank/llama SlyEcho/open_llama_3b_v2_gguf open-llama-3b-v2-f16.gguf - -HUGGING_FACE_TOKENIZER="openlm-research/open_llama_3b_v2" - -python - < /tmp/sharktank/llama/edited_config.json << EOF -{ - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 2048, - "attn_head_count": 32, - "attn_head_dim": 100, - "prefill_batch_sizes": [ - 4 - ], - "decode_batch_sizes": [ - 4 - ], - "transformer_block_count": 26, - "paged_kv_cache": { - "block_seq_stride": 16, - "device_block_count": 256 - } -} -EOF - -# Start the server in the background and save its PID -python -m shortfin_apps.llm.server \ - --tokenizer=/tmp/sharktank/llama/tokenizer.json \ - --model_config=/tmp/sharktank/llama/edited_config.json \ - --vmfb=/tmp/sharktank/llama/model.vmfb \ - --parameters=/tmp/sharktank/llama/open-llama-3b-v2-f16.gguf \ - --device=hip & - -SERVER_PID=$! - -# Wait a bit for the server to start up -sleep 5 - -# Run the client -python client.py - -# Kill the server -kill $SERVER_PID - -# Wait for the server to shut down -wait $SERVER_PID diff --git a/shortfin/python/shortfin_apps/llm/test_llm_server.py b/shortfin/python/shortfin_apps/llm/test_llm_server.py new file mode 100644 index 000000000..51474333f --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/test_llm_server.py @@ -0,0 +1,117 @@ +import pytest +import subprocess +import time +import requests +import os +import json + + +@pytest.fixture(scope="module") +def setup_environment(): + # Create necessary directories + os.makedirs("/tmp/sharktank/llama", exist_ok=True) + + # Download model if it doesn't exist + model_path = "/tmp/sharktank/llama/open-llama-3b-v2-f16.gguf" + if not os.path.exists(model_path): + subprocess.run( + "huggingface-cli download --local-dir /tmp/sharktank/llama SlyEcho/open_llama_3b_v2_gguf open-llama-3b-v2-f16.gguf", + shell=True, + check=True, + ) + + # Set up tokenizer if it doesn't exist + tokenizer_path = "/tmp/sharktank/llama/tokenizer.json" + if not os.path.exists(tokenizer_path): + tokenizer_setup = """ +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b_v2") +tokenizer.save_pretrained("/tmp/sharktank/llama") +""" + subprocess.run(["python", "-c", tokenizer_setup], check=True) + + # Export model if it doesn't exist + mlir_path = "/tmp/sharktank/llama/model.mlir" + config_path = "/tmp/sharktank/llama/config.json" + if not os.path.exists(mlir_path) or not os.path.exists(config_path): + subprocess.run( + [ + "python", + "-m", + "sharktank.examples.export_paged_llm_v1", + f"--gguf-file={model_path}", + f"--output-mlir={mlir_path}", + f"--output-config={config_path}", + ], + check=True, + ) + + # Compile model if it doesn't exist + vmfb_path = "/tmp/sharktank/llama/model.vmfb" + if not os.path.exists(vmfb_path): + subprocess.run( + [ + "iree-compile", + mlir_path, + "--iree-hal-target-backends=rocm", + "--iree-hip-target=gfx1100", + "-o", + vmfb_path, + ], + check=True, + ) + + # Write config if it doesn't exist + edited_config_path = "/tmp/sharktank/llama/edited_config.json" + if not os.path.exists(edited_config_path): + config = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 2048, + "attn_head_count": 32, + "attn_head_dim": 100, + "prefill_batch_sizes": [4], + "decode_batch_sizes": [4], + "transformer_block_count": 26, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + with open(edited_config_path, "w") as f: + json.dump(config, f) + + +@pytest.fixture(scope="module") +def llm_server(setup_environment): + # Start the server + server_process = subprocess.Popen( + [ + "python", + "-m", + "shortfin_apps.llm.server", + "--tokenizer=/tmp/sharktank/llama/tokenizer.json", + "--model_config=/tmp/sharktank/llama/edited_config.json", + "--vmfb=/tmp/sharktank/llama/model.vmfb", + "--parameters=/tmp/sharktank/llama/open-llama-3b-v2-f16.gguf", + "--device=hip", + ] + ) + + # Wait for server to start + time.sleep(5) + + yield server_process + + # Teardown: kill the server + server_process.terminate() + server_process.wait() + + +def test_llm_server(llm_server): + # Here you would typically make requests to your server + # and assert on the responses + # For example: + # response = requests.post("http://localhost:8000/generate", json={"prompt": "Hello, world!"}) + # assert response.status_code == 200 + # assert "generated_text" in response.json() + + # For now, we'll just check if the server process is running + assert llm_server.poll() is None