diff --git a/example.py b/example.py
index 51c7b79..ae5fb13 100644
--- a/example.py
+++ b/example.py
@@ -7,9 +7,9 @@
 # Configuration
 ENDPOINT_ID = os.getenv("ENDPOINT_ID", "your-endpoint-id")
 API_KEY = os.getenv("RUNPOD_API_KEY")
-ENDPOINT_URL = f"https://{ENDPOINT_ID}.api.runpod.ai"
+ENDPOINT_URL = os.getenv("ENDPOINT_URL") or f"https://{ENDPOINT_ID}.api.runpod.ai"  # e.g. ENDPOINT_URL=http://localhost:8009 for local testing
 
-if not API_KEY:
+if not API_KEY and not os.getenv("ENDPOINT_URL"):
     print("Error: Please set RUNPOD_API_KEY environment variable")
     sys.exit(1)
 
@@ -218,7 +218,7 @@ def main():
     print("🧪 vLLM Streaming Test Suite")
     print("=" * 50)
     print(f"🎯 Endpoint: {ENDPOINT_URL}")
-    print(f"🔑 API Key: {API_KEY[:10]}...")
+    print(f"🔑 API Key: {API_KEY[:10] if API_KEY else None}...")
 
     while True:
         print("\n" + "="*50)
diff --git a/src/handler.py b/src/handler.py
index 9f55f10..8b4b383 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -11,7 +11,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 from utils import format_chat_prompt, create_error_response
-from .models import GenerationRequest, GenerationResponse, ChatCompletionRequest
+from models import GenerationRequest, GenerationResponse, ChatCompletionRequest
 
 # Configure logging
 logging.basicConfig(
@@ -53,24 +53,36 @@ async def create_engine():
     global engine, engine_ready
 
     try:
-        # Get model name from environment variable
-        model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
-
+        # Get the model path or registry name from environment variables
+        model_name = os.getenv("MODEL_NAME")
+        served_model_name = os.getenv("SERVED_MODEL_NAME")
+
+        logger.info(f"Model path: {model_name}, Served model name: {served_model_name}")
+
         # Configure engine arguments
         engine_args = AsyncEngineArgs(
             model=model_name,
+            served_model_name=served_model_name,
+            download_dir=os.getenv("DOWNLOAD_DIR"),
+            # mm_processor_cache_gb is not available in vLLM 0.9.1; upgrade to use this parameter
+            # mm_processor_cache_gb=float(os.getenv("MM_PROCESSOR_CACHE_GB")) if os.getenv("MM_PROCESSOR_CACHE_GB") else None,
+            max_num_seqs=int(os.getenv("MAX_NUM_SEQS")) if os.getenv("MAX_NUM_SEQS") else None,
+            hf_token=os.getenv("HF_TOKEN"),
             tensor_parallel_size=int(os.getenv("TENSOR_PARALLEL_SIZE", "1")),
             dtype=os.getenv("DTYPE", "auto"),
             trust_remote_code=os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true",
             max_model_len=int(os.getenv("MAX_MODEL_LEN")) if os.getenv("MAX_MODEL_LEN") else None,
             gpu_memory_utilization=float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9")),
             enforce_eager=os.getenv("ENFORCE_EAGER", "false").lower() == "true",
+            enable_prefix_caching=os.getenv("ENABLE_PREFIX_CACHING", "").lower() == "true" if os.getenv("ENABLE_PREFIX_CACHING") else None,
+            max_num_batched_tokens=int(os.getenv("MAX_NUM_BATCHED_TOKENS")) if os.getenv("MAX_NUM_BATCHED_TOKENS") else None,
+
         )
 
         # Create the engine
         engine = AsyncLLMEngine.from_engine_args(engine_args)
         engine_ready = True
-        logger.info(f"vLLM engine initialized successfully with model: {model_name}")
+        logger.info(f"vLLM engine initialized successfully with model: {model_name} (served as: {served_model_name})")
 
     except Exception as e:
         logger.error(f"Failed to initialize vLLM engine: {str(e)}")
@@ -236,7 +248,7 @@ async def chat_completions(request: ChatCompletionRequest):
         return {
             "id": request_id,
             "object": "chat.completion",
-            "model": os.getenv("MODEL_NAME", "unknown"),
+            "model": os.getenv("SERVED_MODEL_NAME", "unknown"),
             "choices": [{
                 "index": 0,
                 "message": {
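A note on the `create_engine` changes above: every optional engine argument repeats the `int(os.getenv(X)) if os.getenv(X) else None` pattern, and boolean flags are easy to get wrong, since `bool(os.getenv(...))` treats any non-empty string, including `"false"`, as `True`. A minimal sketch of how those reads could be factored into helpers, assuming only the standard library; the `_env_int`/`_env_bool` names are illustrative and not part of this diff:

```python
import os
from typing import Optional


def _env_int(name: str) -> Optional[int]:
    """Return the variable as an int, or None so vLLM falls back to its own default."""
    raw = os.getenv(name)
    return int(raw) if raw else None


def _env_bool(name: str) -> Optional[bool]:
    """Parse 'true'/'false' explicitly; bool(os.getenv(name)) would be True for 'false'."""
    raw = os.getenv(name)
    return raw.lower() == "true" if raw is not None else None


# With MAX_NUM_SEQS=64 and ENABLE_PREFIX_CACHING=false exported, this prints: 64 False
print(_env_int("MAX_NUM_SEQS"), _env_bool("ENABLE_PREFIX_CACHING"))
```

With these, the `AsyncEngineArgs(...)` call would shrink to lines like `max_num_seqs=_env_int("MAX_NUM_SEQS")` and `enable_prefix_caching=_env_bool("ENABLE_PREFIX_CACHING")`.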
diff --git a/src/utils.py b/src/utils.py
index 59ebb0d..ac35b81 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,6 +1,6 @@
 from typing import List
 from transformers import AutoTokenizer
-from .models import ChatMessage, ErrorResponse
+from models import ChatMessage, ErrorResponse
 
 
 def get_tokenizer(model_name: str):
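Both `src/handler.py` and `src/utils.py` drop the leading dot from their `models` imports because the handler runs as a top-level script rather than as part of a `src` package, and relative imports then raise `ImportError: attempted relative import with no known parent package`. A self-contained reproduction of that failure mode, independent of this repo:

```python
"""Demonstrates why `from .models import ...` fails when a file runs as a script."""
import subprocess
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp)
    (src / "models.py").write_text("X = 1\n")
    (src / "handler_rel.py").write_text("from .models import X\n")
    (src / "handler_abs.py").write_text("from models import X\nprint('ok', X)\n")

    # Relative import: fails when executed as a script with
    # "ImportError: attempted relative import with no known parent package".
    rel = subprocess.run([sys.executable, str(src / "handler_rel.py")],
                         capture_output=True, text=True)
    print("relative import exit code:", rel.returncode)  # non-zero

    # Absolute import: works because the script's own directory heads sys.path.
    ab = subprocess.run([sys.executable, str(src / "handler_abs.py")],
                        capture_output=True, text=True)
    print("absolute import output:", ab.stdout.strip())  # ok 1
```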