Add device selection to shortfin llm demo #275

Merged: 6 commits, Oct 17, 2024
64 changes: 64 additions & 0 deletions shortfin/python/shortfin_apps/llm/client.py
@@ -0,0 +1,64 @@
import requests
import json
import uuid

BASE_URL = "http://localhost:8000"


def test_health():
    response = requests.get(f"{BASE_URL}/health")
    print(f"Health check status code: {response.status_code}")
    return response.status_code == 200


def test_generate():
    headers = {"Content-Type": "application/json"}

    # Create a GenerateReqInput-like structure
    data = {
        "text": "1 2 3 4 5",
        "sampling_params": {"max_tokens": 50, "temperature": 0.7},
        "rid": uuid.uuid4().hex,
        "return_logprob": False,
        "logprob_start_len": -1,
        "top_logprobs_num": 0,
        "return_text_in_logprobs": False,
        "stream": False,
    }

    print("Prompt text:")
    print(data["text"])

    response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
    print(f"Generate endpoint status code: {response.status_code}")

    if response.status_code == 200:
        print("Generated text:")
        data = response.text
        # The response body is wrapped in "data: ...\n\n" framing;
        # strip the prefix and trailing newlines before printing.
        assert data.startswith("data: ")
        data = data[6:]
        assert data.endswith("\n\n")
        data = data[:-2]
        print(data)
    else:
        print("Failed to generate text")
        print("Response content:")
        print(response.text)

    return response.status_code == 200


def main():
    print("Testing webapp...")

    health_ok = test_health()
    generate_ok = test_generate()

    if health_ok and generate_ok:
        print("\nAll tests passed successfully!")
    else:
        print("\nSome tests failed. Please check the output above for details.")


if __name__ == "__main__":
    main()
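
For orientation (not part of the diff itself): client.py assumes a shortfin LLM server is already listening at BASE_URL. A minimal usage sketch, assuming the package layout makes the file importable as `shortfin_apps.llm.client`; otherwise the script can simply be run directly thanks to the `__main__` guard.

# Usage sketch (assumption: a server is already running on http://localhost:8000
# and shortfin_apps is importable from the current environment).
from shortfin_apps.llm import client

health_ok = client.test_health()      # GET /health
generate_ok = client.test_generate()  # POST /generate with the fixed "1 2 3 4 5" prompt
print("client checks passed" if health_ok and generate_ok else "client checks failed")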
7 changes: 5 additions & 2 deletions shortfin/python/shortfin_apps/llm/components/manager.py
@@ -13,8 +13,11 @@


 class SystemManager:
-    def __init__(self):
-        self.ls = sf.host.CPUSystemBuilder().create_system()
+    def __init__(self, device="local-task"):
+        if device == "local-task":
+            self.ls = sf.host.CPUSystemBuilder().create_system()
+        elif device == "hip":
+            self.ls = sf.amdgpu.SystemBuilder().create_system()
         logger.info(f"Created local system with {self.ls.device_names} devices")
         # TODO: Come up with an easier bootstrap thing than manually
         # running a thread.
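Note that the new constructor assigns nothing when the device string is neither "local-task" nor "hip", so self.ls would be unset and the following logger.info call would raise AttributeError. A defensive variant, shown purely as a sketch and not as part of this change:

    # Sketch only (not in this PR): reject unsupported device strings up front.
    def __init__(self, device="local-task"):
        if device == "local-task":
            self.ls = sf.host.CPUSystemBuilder().create_system()
        elif device == "hip":
            self.ls = sf.amdgpu.SystemBuilder().create_system()
        else:
            raise ValueError(f"Unsupported device '{device}'; expected 'local-task' or 'hip'")
        logger.info(f"Created local system with {self.ls.device_names} devices")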
12 changes: 10 additions & 2 deletions shortfin/python/shortfin_apps/llm/server.py
@@ -76,7 +76,7 @@ async def generate_request(gen_req: GenerateReqInput, request: Request):
 
 def configure(args) -> SystemManager:
     # Setup system (configure devices, etc).
-    sysman = SystemManager()
+    sysman = SystemManager(device=args.device)
 
     # Setup each service we are hosting.
     tokenizer = Tokenizer.from_tokenizer_json_file(args.tokenizer_json)
@@ -121,11 +121,19 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
         required=True,
         help="Model VMFB to load",
     )
+    # parameters are loaded with `iree_io_parameters_module_create`
     parser.add_argument(
         "--parameters",
         type=Path,
         nargs="*",
-        help="Parameter archives to load",
+        help="Parameter archives to load (supports: gguf, irpa, safetensors).",
+        metavar="FILE",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="local-task",
+        help="Device to serve on; e.g. local-task, hip. Same options as `iree-run-module --device` ",
+    )
     args = parser.parse_args(argv)
     global sysman
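With `--device` wired through `configure()` into SystemManager, an example AMD GPU launch mirrors the test fixture added below: `python -m shortfin_apps.llm.server --tokenizer=/tmp/sharktank/llama/tokenizer.json --model_config=/tmp/sharktank/llama/edited_config.json --vmfb=/tmp/sharktank/llama/model.vmfb --parameters=/tmp/sharktank/llama/open-llama-3b-v2-f16.gguf --device=hip` (paths are illustrative). Omitting `--device` keeps the previous CPU `local-task` behavior.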
117 changes: 117 additions & 0 deletions shortfin/python/shortfin_apps/llm/test_llm_server.py
@@ -0,0 +1,117 @@
import pytest
import subprocess
import time
import requests
import os
import json


@pytest.fixture(scope="module")
def setup_environment():
    # Create necessary directories
    os.makedirs("/tmp/sharktank/llama", exist_ok=True)

    # Download model if it doesn't exist
    model_path = "/tmp/sharktank/llama/open-llama-3b-v2-f16.gguf"
    if not os.path.exists(model_path):
        subprocess.run(
            "huggingface-cli download --local-dir /tmp/sharktank/llama SlyEcho/open_llama_3b_v2_gguf open-llama-3b-v2-f16.gguf",
            shell=True,
            check=True,
        )

    # Set up tokenizer if it doesn't exist
    tokenizer_path = "/tmp/sharktank/llama/tokenizer.json"
    if not os.path.exists(tokenizer_path):
        tokenizer_setup = """
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
tokenizer.save_pretrained("/tmp/sharktank/llama")
"""
        subprocess.run(["python", "-c", tokenizer_setup], check=True)

    # Export model if it doesn't exist
    mlir_path = "/tmp/sharktank/llama/model.mlir"
    config_path = "/tmp/sharktank/llama/config.json"
    if not os.path.exists(mlir_path) or not os.path.exists(config_path):
        subprocess.run(
            [
                "python",
                "-m",
                "sharktank.examples.export_paged_llm_v1",
                f"--gguf-file={model_path}",
                f"--output-mlir={mlir_path}",
                f"--output-config={config_path}",
            ],
            check=True,
        )

    # Compile model if it doesn't exist
    vmfb_path = "/tmp/sharktank/llama/model.vmfb"
    if not os.path.exists(vmfb_path):
        subprocess.run(
            [
                "iree-compile",
                mlir_path,
                "--iree-hal-target-backends=rocm",
                "--iree-hip-target=gfx1100",
                "-o",
                vmfb_path,
            ],
            check=True,
        )

    # Write config if it doesn't exist
    edited_config_path = "/tmp/sharktank/llama/edited_config.json"
    if not os.path.exists(edited_config_path):
        config = {
            "module_name": "module",
            "module_abi_version": 1,
            "max_seq_len": 2048,
            "attn_head_count": 32,
            "attn_head_dim": 100,
            "prefill_batch_sizes": [4],
            "decode_batch_sizes": [4],
            "transformer_block_count": 26,
            "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
        }
        with open(edited_config_path, "w") as f:
            json.dump(config, f)


@pytest.fixture(scope="module")
def llm_server(setup_environment):
    # Start the server
    server_process = subprocess.Popen(
        [
            "python",
            "-m",
            "shortfin_apps.llm.server",
            "--tokenizer=/tmp/sharktank/llama/tokenizer.json",
            "--model_config=/tmp/sharktank/llama/edited_config.json",
            "--vmfb=/tmp/sharktank/llama/model.vmfb",
            "--parameters=/tmp/sharktank/llama/open-llama-3b-v2-f16.gguf",
            "--device=hip",
        ]
    )

    # Wait for server to start
    time.sleep(5)

    yield server_process

    # Teardown: kill the server
    server_process.terminate()
    server_process.wait()


def test_llm_server(llm_server):
    # Here you would typically make requests to your server
    # and assert on the responses
    # For example:
    # response = requests.post("http://localhost:8000/generate", json={"prompt": "Hello, world!"})
    # assert response.status_code == 200
    # assert "generated_text" in response.json()

    # For now, we'll just check if the server process is running
    assert llm_server.poll() is None
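
As the comments note, the test currently only checks that the server process stays alive. A possible follow-up, sketched here under the assumption that the /generate contract matches what client.py above exercises (the full GenerateReqInput-style payload may still be required), would be:

# Sketch only (not in this PR): hit /generate the same way client.py does.
def test_generate_endpoint(llm_server):
    payload = {
        "text": "1 2 3 4 5",
        "sampling_params": {"max_tokens": 50, "temperature": 0.7},
        "stream": False,
    }
    response = requests.post("http://localhost:8000/generate", json=payload)
    assert response.status_code == 200
    assert response.text.startswith("data: ")  # same framing that client.py strips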