From 81cd9f2d080d066d247af872bd1d9f54b676b5b6 Mon Sep 17 00:00:00 2001 From: Fiyi Adebekun Date: Thu, 27 Jun 2024 09:11:22 -0400 Subject: [PATCH 1/2] fix s3 base model --- server/lorax_server/models/__init__.py | 6 ++++-- server/lorax_server/utils/sources/s3.py | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 802fb40f4..514f87e51 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -15,6 +15,7 @@ from lorax_server.models.seq2seq_lm import Seq2SeqLM from lorax_server.models.t5 import T5Sharded from lorax_server.utils.sources import get_s3_model_local_dir +from lorax_server.utils.sources.s3 import S3ModelSource # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -57,7 +58,8 @@ def get_model( # change the model id to be the local path to the folder so # we can load the config_dict locally logger.info("Using the local files since we are coming from s3") - model_path = get_s3_model_local_dir(model_id) + model_source = S3ModelSource(model_id, revision) + model_path = model_source.get_local_path() logger.info(f"model_path: {model_path}") config_dict, _ = PretrainedConfig.get_config_dict( model_path, revision=revision, trust_remote_code=trust_remote_code @@ -96,7 +98,7 @@ def get_model( from lorax_server.models.flash_bert import FlashBert return FlashBert(model_id, revision=revision, dtype=dtype) - + if model_type == "distilbert": from lorax_server.models.flash_distilbert import FlashDistilBert diff --git a/server/lorax_server/utils/sources/s3.py b/server/lorax_server/utils/sources/s3.py index 085fbaeef..cf08d137c 100644 --- a/server/lorax_server/utils/sources/s3.py +++ b/server/lorax_server/utils/sources/s3.py @@ -227,9 +227,8 @@ def download_weights(self, filenames: List[str]): def download_model_assets(self): return download_model_from_s3(self.bucket, self.model_id, self.extension) - def get_local_path(self, model_id: str): - _, model_id = _get_bucket_and_model_id(model_id) - return get_s3_model_local_dir(model_id) + def get_local_path(self): + return get_s3_model_local_dir(self.model_id) def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: filenames = [filename] From 30800456cb81c1373c5fc8a12e7ef240dba0063b Mon Sep 17 00:00:00 2001 From: Fiyi Adebekun Date: Thu, 27 Jun 2024 09:55:04 -0400 Subject: [PATCH 2/2] lint --- server/lorax_server/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 514f87e51..0652463d9 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -14,7 +14,6 @@ from lorax_server.models.santacoder import SantaCoder from lorax_server.models.seq2seq_lm import Seq2SeqLM from lorax_server.models.t5 import T5Sharded -from lorax_server.utils.sources import get_s3_model_local_dir from lorax_server.utils.sources.s3 import S3ModelSource # The flag below controls whether to allow TF32 on matmul. This flag defaults to False