Skip to content

Commit

Permalink
Add unit tests to weight-utils
Browse files Browse the repository at this point in the history
Signed-off-by: OmerD <[email protected]>
  • Loading branch information
omer-dayan committed Dec 1, 2024
1 parent 7df80d3 commit db083db
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
Empty file.
31 changes: 31 additions & 0 deletions tests/runai_model_streamer/test_runai_model_streamer_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from vllm import SamplingParams
from vllm.config import LoadConfig, LoadFormat
from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader,
get_model_loader)

test_model = "openai-community/gpt2"

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)


def get_runai_model_loader():
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
return get_model_loader(load_config)


def test_get_model_loader_with_runai_flag():
model_loader = get_runai_model_loader()
assert isinstance(model_loader, RunaiModelStreamerLoader)


def test_runai_model_loader_download_files(vllm_runner):
with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs
39 changes: 39 additions & 0 deletions tests/runai_model_streamer/test_weight_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import glob
import tempfile

import huggingface_hub.constants
import torch

from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, runai_safetensors_weights_iterator,
safetensors_weights_iterator)


def test_runai_model_loader():
with tempfile.TemporaryDirectory() as tmpdir:
huggingface_hub.constants.HF_HUB_OFFLINE = False
download_weights_from_hf("openai-community/gpt2",
allow_patterns=["*.safetensors"],
cache_dir=tmpdir)
safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
assert len(safetensors) > 0

runai_model_streamer_tensors = {}
hf_safetensors_tensors = {}

for name, tensor in runai_safetensors_weights_iterator(safetensors):
runai_model_streamer_tensors[name] = tensor

for name, tensor in safetensors_weights_iterator(safetensors):
hf_safetensors_tensors[name] = tensor

assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)

for name, runai_tensor in runai_model_streamer_tensors.items():
assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype
assert runai_tensor.shape == hf_safetensors_tensors[name].shape
assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name]))


if __name__ == "__main__":
test_runai_model_loader()
6 changes: 4 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
with target_device:
model = _initialize_model(vllm_config=vllm_config)

assert hasattr(model_config, "model_weights")
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
model.load_weights(
self._get_weights_iterator(model_config.model_weights,
self._get_weights_iterator(model_weights,
model_config.revision))

for _, module in model.named_modules():
Expand Down

0 comments on commit db083db

Please sign in to comment.