diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 00c82fb77186c..a9f1e6e88d792 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -import transformers from transformers import PretrainedConfig from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig @@ -108,39 +107,11 @@ def load_weights(self, model_name_or_path: str, **kwargs): neuronx_module = importlib.import_module(neuronx_module_path) neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - split_model_dir = f"{model_name_or_path}-split" - if _is_pretrained_neuron_checkpoint(model_name_or_path): - split_model_dir = model_name_or_path - elif not os.path.exists(f"{model_name_or_path}-split"): - hf_model_cls = getattr(transformers, hf_model_cls_name) - from transformers_neuronx.module import save_pretrained_split - - hf_model = hf_model_cls.from_pretrained(model_name_or_path, - low_cpu_mem_usage=True) - save_pretrained_split(hf_model, f"{model_name_or_path}-split") - - self.model = neuronx_model_cls.from_pretrained(split_model_dir, + self.model = neuronx_model_cls.from_pretrained(model_name_or_path, **kwargs) self.model.to_neuron() -def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool: - # Checking if the neuron checkpoint is saved in the old format. - if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")): - return True - # Checking if the neuron checkpoint is saved in the new format. - pretrained_split_files = ["config.json", "generation_config.json"] - pretrained_split_format = ".safetensors" - for file in pretrained_split_files: - file_path = os.path.join(model_name_or_path, file) - if not os.path.isfile(file_path): - return False - for file in os.listdir(model_name_or_path): - if file.endswith(pretrained_split_format): - return True - return False - - def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: