diff --git a/aviary/backend/llm/predictor.py b/aviary/backend/llm/predictor.py
index 40c21dec..3b83725e 100644
--- a/aviary/backend/llm/predictor.py
+++ b/aviary/backend/llm/predictor.py
@@ -324,7 +324,7 @@ async def _create_worker_group(
         await asyncio.gather(
             *[
                 initialize_node_remote_pg.remote(
-                    llm_config.model_id,
+                    llm_config.actual_hf_model_id,
                     llm_config.initialization.s3_mirror_config,
                 )
                 for i in range(scaling_config.num_workers)
diff --git a/aviary/backend/llm/utils.py b/aviary/backend/llm/utils.py
index b84bce4a..c64eadbe 100644
--- a/aviary/backend/llm/utils.py
+++ b/aviary/backend/llm/utils.py
@@ -32,7 +32,7 @@ def download_model(
     Download a model from an S3 bucket and save it in TRANSFORMERS_CACHE for
     seamless interoperability with Hugging Face's Transformers library.
 
-    The downloaded model must have a 'hash' file containing the commit hash corresponding
+    The downloaded model may have a 'hash' file containing the commit hash corresponding
     to the commit on Hugging Face Hub.
     """
     from transformers.utils.hub import TRANSFORMERS_CACHE
@@ -47,11 +47,13 @@ def download_model(
         + [os.path.join(bucket_uri, "hash"), "."]
     )
     if not os.path.exists(os.path.join(".", "hash")):
-        raise RuntimeError(
-            "Hash file not found in the bucket or bucket could not have been downloaded."
+        f_hash = "0000000000000000000000000000000000000000"
+        logger.warning(
+            f"hash file does not exist in {bucket_uri}. Using {f_hash} as the hash."
         )
-    with open(os.path.join(".", "hash"), "r") as f:
-        f_hash = f.read().strip()
+    else:
+        with open(os.path.join(".", "hash"), "r") as f:
+            f_hash = f.read().strip()
     logger.info(
         f"Downloading {model_id} from {bucket_uri} to {os.path.join(path, 'snapshots', f_hash)}"
     )
diff --git a/aviary/backend/server/models.py b/aviary/backend/server/models.py
index 74890553..f7365172 100644
--- a/aviary/backend/server/models.py
+++ b/aviary/backend/server/models.py
@@ -305,6 +305,15 @@ def initializer_pipeline(cls, values):
             )
         return values
 
+    @root_validator
+    def s3_mirror_config_transformers(cls, values):
+        s3_mirror_config: S3MirrorConfig = values.get("s3_mirror_config")
+        if s3_mirror_config.bucket_uri:
+            initializer: Initializer = values.get("initializer")
+            if isinstance(initializer, Transformers):
+                initializer.from_pretrained_kwargs["local_files_only"] = True
+        return values
+
 
 class GenerationConfig(BaseModelExtended):
     prompt_format: Optional[str] = None