Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Use correct local path when loading base model from s3 #528

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions server/lorax_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from lorax_server.models.santacoder import SantaCoder
from lorax_server.models.seq2seq_lm import Seq2SeqLM
from lorax_server.models.t5 import T5Sharded
-from lorax_server.utils.sources import get_s3_model_local_dir
+from lorax_server.utils.sources.s3 import S3ModelSource

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
Expand Down Expand Up @@ -57,7 +57,8 @@ def get_model(
# change the model id to be the local path to the folder so
# we can load the config_dict locally
logger.info("Using the local files since we are coming from s3")
-model_path = get_s3_model_local_dir(model_id)
+model_source = S3ModelSource(model_id, revision)
+model_path = model_source.get_local_path()
logger.info(f"model_path: {model_path}")
config_dict, _ = PretrainedConfig.get_config_dict(
model_path, revision=revision, trust_remote_code=trust_remote_code
Expand Down Expand Up @@ -96,7 +97,7 @@ def get_model(
from lorax_server.models.flash_bert import FlashBert

return FlashBert(model_id, revision=revision, dtype=dtype)

if model_type == "distilbert":
from lorax_server.models.flash_distilbert import FlashDistilBert

Expand Down
5 changes: 2 additions & 3 deletions server/lorax_server/utils/sources/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,8 @@ def download_weights(self, filenames: List[str]):
def download_model_assets(self):
return download_model_from_s3(self.bucket, self.model_id, self.extension)

-def get_local_path(self, model_id: str):
-_, model_id = _get_bucket_and_model_id(model_id)
-return get_s3_model_local_dir(model_id)
+def get_local_path(self):
+return get_s3_model_local_dir(self.model_id)

def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]:
filenames = [filename]
Expand Down
Loading