[Hardware][Neuron] Simplify model load for transformers-neuronx libr…
sssrijan-amazon authored Oct 17, 2024
1 parent d615b5c commit bb76538
Showing 1 changed file with 1 addition and 30 deletions.
31 changes: 1 addition & 30 deletions vllm/model_executor/model_loader/neuron.py
@@ -6,7 +6,6 @@
 
 import torch
 import torch.nn as nn
-import transformers
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
@@ -108,39 +107,11 @@ def load_weights(self, model_name_or_path: str, **kwargs):
         neuronx_module = importlib.import_module(neuronx_module_path)
         neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
 
-        split_model_dir = f"{model_name_or_path}-split"
-        if _is_pretrained_neuron_checkpoint(model_name_or_path):
-            split_model_dir = model_name_or_path
-        elif not os.path.exists(f"{model_name_or_path}-split"):
-            hf_model_cls = getattr(transformers, hf_model_cls_name)
-            from transformers_neuronx.module import save_pretrained_split
-
-            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
-                                                    low_cpu_mem_usage=True)
-            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
-
-        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
+        self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
                                                        **kwargs)
         self.model.to_neuron()
 
 
-def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
-    # Checking if the neuron checkpoint is saved in the old format.
-    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
-        return True
-    # Checking if the neuron checkpoint is saved in the new format.
-    pretrained_split_files = ["config.json", "generation_config.json"]
-    pretrained_split_format = ".safetensors"
-    for file in pretrained_split_files:
-        file_path = os.path.join(model_name_or_path, file)
-        if not os.path.isfile(file_path):
-            return False
-    for file in os.listdir(model_name_or_path):
-        if file.endswith(pretrained_split_format):
-            return True
-    return False
-
-
 def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:

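For context, a minimal sketch of the load path the simplified code relies on, assuming a transformers-neuronx release that can consume a Hugging Face-format checkpoint directly (so no "<model>-split" directory or save_pretrained_split step is needed). The checkpoint path and the batch_size/tp_degree/amp values below are illustrative placeholders, not part of this commit.

# Hypothetical usage sketch: load a Llama checkpoint straight from its
# Hugging Face-format directory with transformers-neuronx. The path and the
# compile-time options are example values, not taken from the diff.
from transformers_neuronx.llama.model import LlamaForSampling

model = LlamaForSampling.from_pretrained(
    "/path/to/llama-checkpoint",  # HF-format checkpoint directory (placeholder)
    batch_size=1,                 # batch size compiled into the Neuron graph
    tp_degree=2,                  # tensor parallelism across NeuronCores
    amp="f16",                    # run the model in fp16
)
model.to_neuron()  # compile and load the model onto the Neuron devices

This mirrors what load_weights now does: it passes model_name_or_path and **kwargs straight to neuronx_model_cls.from_pretrained and then calls to_neuron(), with no intermediate split checkpoint written to disk.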