diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 802fb40f4..0652463d9 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -14,7 +14,7 @@ from lorax_server.models.santacoder import SantaCoder from lorax_server.models.seq2seq_lm import Seq2SeqLM from lorax_server.models.t5 import T5Sharded -from lorax_server.utils.sources import get_s3_model_local_dir +from lorax_server.utils.sources.s3 import S3ModelSource # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -57,7 +57,8 @@ def get_model( # change the model id to be the local path to the folder so # we can load the config_dict locally logger.info("Using the local files since we are coming from s3") - model_path = get_s3_model_local_dir(model_id) + model_source = S3ModelSource(model_id, revision) + model_path = model_source.get_local_path() logger.info(f"model_path: {model_path}") config_dict, _ = PretrainedConfig.get_config_dict( model_path, revision=revision, trust_remote_code=trust_remote_code @@ -96,7 +97,7 @@ def get_model( from lorax_server.models.flash_bert import FlashBert return FlashBert(model_id, revision=revision, dtype=dtype) - + if model_type == "distilbert": from lorax_server.models.flash_distilbert import FlashDistilBert diff --git a/server/lorax_server/utils/sources/s3.py b/server/lorax_server/utils/sources/s3.py index 085fbaeef..cf08d137c 100644 --- a/server/lorax_server/utils/sources/s3.py +++ b/server/lorax_server/utils/sources/s3.py @@ -227,9 +227,8 @@ def download_weights(self, filenames: List[str]): def download_model_assets(self): return download_model_from_s3(self.bucket, self.model_id, self.extension) - def get_local_path(self, model_id: str): - _, model_id = _get_bucket_and_model_id(model_id) - return get_s3_model_local_dir(model_id) + def get_local_path(self): + return get_s3_model_local_dir(self.model_id) def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: filenames = [filename]