diff --git a/tests/runai_model_streamer/__init__.py b/tests/runai_model_streamer/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/runai_model_streamer/test_runai_model_streamer_loader.py b/tests/runai_model_streamer/test_runai_model_streamer_loader.py
new file mode 100644
index 0000000000000..c5722fbae5c8a
--- /dev/null
+++ b/tests/runai_model_streamer/test_runai_model_streamer_loader.py
@@ -0,0 +1,31 @@
+from vllm import SamplingParams
+from vllm.config import LoadConfig, LoadFormat
+from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader,
+                                                     get_model_loader)
+
+test_model = "openai-community/gpt2"
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
+
+
+def get_runai_model_loader():
+    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
+    return get_model_loader(load_config)
+
+
+def test_get_model_loader_with_runai_flag():
+    model_loader = get_runai_model_loader()
+    assert isinstance(model_loader, RunaiModelStreamerLoader)
+
+
+def test_runai_model_loader_download_files(vllm_runner):
+    with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
+        deserialized_outputs = llm.generate(prompts, sampling_params)
+        assert deserialized_outputs
diff --git a/tests/runai_model_streamer/test_weight_utils.py b/tests/runai_model_streamer/test_weight_utils.py
new file mode 100644
index 0000000000000..5c89bd78ad81d
--- /dev/null
+++ b/tests/runai_model_streamer/test_weight_utils.py
@@ -0,0 +1,39 @@
+import glob
+import tempfile
+
+import huggingface_hub.constants
+import torch
+
+from vllm.model_executor.model_loader.weight_utils import (
+    download_weights_from_hf, runai_safetensors_weights_iterator,
+    safetensors_weights_iterator)
+
+
+def test_runai_model_loader():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        huggingface_hub.constants.HF_HUB_OFFLINE = False
+        download_weights_from_hf("openai-community/gpt2",
+                                 allow_patterns=["*.safetensors"],
+                                 cache_dir=tmpdir)
+        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
+        assert len(safetensors) > 0
+
+        runai_model_streamer_tensors = {}
+        hf_safetensors_tensors = {}
+
+        for name, tensor in runai_safetensors_weights_iterator(safetensors):
+            runai_model_streamer_tensors[name] = tensor
+
+        for name, tensor in safetensors_weights_iterator(safetensors):
+            hf_safetensors_tensors[name] = tensor
+
+        assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)
+
+        for name, runai_tensor in runai_model_streamer_tensors.items():
+            assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype
+            assert runai_tensor.shape == hf_safetensors_tensors[name].shape
+            assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name]))
+
+
+if __name__ == "__main__":
+    test_runai_model_loader()
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index b4bd5194cd63b..2bd874b632608 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1321,9 +1321,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         with target_device:
             model = _initialize_model(vllm_config=vllm_config)

-            assert hasattr(model_config, "model_weights")
+            model_weights = model_config.model
+            if hasattr(model_config, "model_weights"):
+                model_weights = model_config.model_weights
             model.load_weights(
                 self._get_weights_iterator(model_weights,
                                            model_config.revision))

         for _, module in model.named_modules():
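
For reference, the load_model change above makes the loader fall back to model_config.model when no model_weights attribute is set on the model config, so a plain Hugging Face model ID can be streamed end to end. A minimal usage sketch, not part of this patch, assuming the runai-model-streamer extra is installed and that LLM forwards load_format to the engine args (the model and sampling parameters mirror the tests above):

    from vllm import LLM, SamplingParams

    # Load weights through the RunAI model streamer. With the fallback
    # introduced above, the weights location defaults to the model ID
    # itself when no explicit model_weights override is present.
    llm = LLM(model="openai-community/gpt2", load_format="runai_streamer")

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    print(outputs[0].outputs[0].text)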