diff --git a/tests/runai_model_streamer/__init__.py b/tests/runai_model_streamer/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/runai_model_streamer/test_runai_model_streamer_loader.py b/tests/runai_model_streamer/test_runai_model_streamer_loader.py
new file mode 100644
index 0000000000000..ce616d8504698
--- /dev/null
+++ b/tests/runai_model_streamer/test_runai_model_streamer_loader.py
@@ -0,0 +1,32 @@
+from vllm.model_executor.model_loader.loader import (get_model_loader, RunaiModelStreamerLoader)
+from vllm.config import (LoadConfig, LoadFormat)
+from ..conftest import VllmRunner
+from vllm import SamplingParams
+
+
+test_model = "openai-community/gpt2"
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
+
+def get_runai_model_loader():
+    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
+    return get_model_loader(load_config)
+
+def test_get_model_loader_with_runai_flag():
+    model_loader = get_runai_model_loader()
+    assert isinstance(model_loader, RunaiModelStreamerLoader)
+
+def test_runai_model_loader_download_files(vllm_runner: VllmRunner):
+    with vllm_runner(
+            test_model,
+            load_format=LoadFormat.RUNAI_STREAMER) as llm:
+        deserialized_outputs = llm.generate(
+            prompts, sampling_params)
+        assert deserialized_outputs
diff --git a/tests/runai_model_streamer/test_weight_utils.py b/tests/runai_model_streamer/test_weight_utils.py
new file mode 100644
index 0000000000000..da1308ac7aca0
--- /dev/null
+++ b/tests/runai_model_streamer/test_weight_utils.py
@@ -0,0 +1,38 @@
+import glob
+import tempfile
+import torch
+
+import huggingface_hub.constants
+
+from vllm.model_executor.model_loader.weight_utils import (
+    download_weights_from_hf, runai_safetensors_weights_iterator, safetensors_weights_iterator)
+
+def test_runai_model_loader():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        huggingface_hub.constants.HF_HUB_OFFLINE = False
+        download_weights_from_hf("openai-community/gpt2",
+                                 allow_patterns=["*.safetensors"],
+                                 cache_dir=tmpdir)
+        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
+        assert len(safetensors) > 0
+
+        runai_model_streamer_tensors = {}
+        hf_safetensors_tensors = {}
+
+        for name, tensor in runai_safetensors_weights_iterator(safetensors):
+            runai_model_streamer_tensors[name] = tensor
+
+        for name, tensor in safetensors_weights_iterator(safetensors):
+            hf_safetensors_tensors[name] = tensor
+
+        assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)
+
+        for name, runai_tensor in runai_model_streamer_tensors.items():
+            assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype
+            assert runai_tensor.shape == hf_safetensors_tensors[name].shape
+            assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name]))
+
+
+
+if __name__ == "__main__":
+    test_runai_model_loader()
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index b4bd5194cd63b..2bd874b632608
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1321,9 +1321,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
             with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
 
-            assert hasattr(model_config, "model_weights")
+            model_weights = model_config.model
+            if hasattr(model_config, "model_weights"):
+                model_weights = model_config.model_weights
             model.load_weights(
-                self._get_weights_iterator(model_config.model_weights,
+                self._get_weights_iterator(model_weights,
                                            model_config.revision))
 
             for _, module in model.named_modules():
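
The loader.py hunk replaces a hard assert with a fallback: when `model_config` has no `model_weights` attribute, the streamer loads from `model_config.model` instead of crashing, so one code path covers both plain Hugging Face model IDs and configs that carry an explicit weights location. As a usage illustration (not part of the diff), here is a minimal sketch of driving this load format through vLLM's public API; it assumes a vLLM build with the Run:ai Model Streamer extra installed, and it mirrors the model and sampling settings used in the tests above:

    from vllm import LLM, SamplingParams

    # "runai_streamer" is the string form of LoadFormat.RUNAI_STREAMER.
    llm = LLM(model="openai-community/gpt2", load_format="runai_streamer")

    # Same sampling settings as the loader test above.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    print(outputs[0].outputs[0].text)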