Skip to content

Commit

Permalink
Add unit tests to weight-utils
Browse files Browse the repository at this point in the history
Signed-off-by: OmerD <[email protected]>
  • Loading branch information
omer-dayan committed Dec 1, 2024
1 parent 7df80d3 commit 554c859
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
Empty file.
32 changes: 32 additions & 0 deletions tests/runai_model_streamer/test_runai_model_streamer_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from vllm.model_executor.model_loader.loader import (get_model_loader, RunaiModelStreamerLoader)

Check failure on line 1 in tests/runai_model_streamer/test_runai_model_streamer_loader.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

tests/runai_model_streamer/test_runai_model_streamer_loader.py:1:81: E501 Line too long (96 > 80)
from vllm.config import (LoadConfig, LoadFormat)
from ..conftest import VllmRunner
from vllm import SamplingParams


test_model = "openai-community/gpt2"

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)

def get_runai_model_loader():
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
return get_model_loader(load_config)

def test_get_model_loader_with_runai_flag():
model_loader = get_runai_model_loader()
assert isinstance(model_loader, RunaiModelStreamerLoader)

def test_runai_model_loader_download_files(vllm_runner: VllmRunner):
with vllm_runner(

Check failure on line 27 in tests/runai_model_streamer/test_runai_model_streamer_loader.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

"VllmRunner" not callable [operator]

Check failure on line 27 in tests/runai_model_streamer/test_runai_model_streamer_loader.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

"VllmRunner" not callable [operator]

Check failure on line 27 in tests/runai_model_streamer/test_runai_model_streamer_loader.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

"VllmRunner" not callable [operator]

Check failure on line 27 in tests/runai_model_streamer/test_runai_model_streamer_loader.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

"VllmRunner" not callable [operator]
test_model,
load_format=LoadFormat.RUNAI_STREAMER) as llm:
deserialized_outputs = llm.generate(
prompts, sampling_params)
assert deserialized_outputs
38 changes: 38 additions & 0 deletions tests/runai_model_streamer/test_weight_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import glob
import tempfile
import torch

import huggingface_hub.constants

from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, runai_safetensors_weights_iterator, safetensors_weights_iterator)

Check failure on line 8 in tests/runai_model_streamer/test_weight_utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

tests/runai_model_streamer/test_weight_utils.py:8:81: E501 Line too long (95 > 80)

def test_runai_model_loader():
with tempfile.TemporaryDirectory() as tmpdir:
huggingface_hub.constants.HF_HUB_OFFLINE = False
download_weights_from_hf("openai-community/gpt2",
allow_patterns=["*.safetensors"],
cache_dir=tmpdir)
safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
assert len(safetensors) > 0

runai_model_streamer_tensors = {}
hf_safetensors_tensors = {}

for name, tensor in runai_safetensors_weights_iterator(safetensors):
runai_model_streamer_tensors[name] = tensor

for name, tensor in safetensors_weights_iterator(safetensors):
hf_safetensors_tensors[name] = tensor

assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)

for name, runai_tensor in runai_model_streamer_tensors.items():
assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype
assert runai_tensor.shape == hf_safetensors_tensors[name].shape
assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name]))



if __name__ == "__main__":
test_runai_model_loader()
6 changes: 4 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
with target_device:
model = _initialize_model(vllm_config=vllm_config)

assert hasattr(model_config, "model_weights")
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
model.load_weights(
self._get_weights_iterator(model_config.model_weights,
self._get_weights_iterator(model_weights,
model_config.revision))

for _, module in model.named_modules():
Expand Down

0 comments on commit 554c859

Please sign in to comment.