diff --git a/server/lorax_server/utils/sources/s3.py b/server/lorax_server/utils/sources/s3.py index 241c165f7..d0d71af6b 100644 --- a/server/lorax_server/utils/sources/s3.py +++ b/server/lorax_server/utils/sources/s3.py @@ -1,7 +1,7 @@ import os import time from datetime import timedelta -from typing import Optional, List, Any +from typing import TYPE_CHECKING, Optional, List, Any, Tuple from loguru import logger from pathlib import Path @@ -9,7 +9,6 @@ from botocore.config import Config from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE - from huggingface_hub.utils import ( LocalEntryNotFoundError, EntryNotFoundError, @@ -17,8 +16,41 @@ from .source import BaseModelSource, try_to_load_from_cache +if TYPE_CHECKING: + from boto3.resources.factory.s3 import Bucket + + +S3_PREFIX = "s3://" + + +def _get_bucket_and_model_id(model_id: str) -> Tuple[str, str]: + if model_id.startswith(S3_PREFIX): + model_id_no_protocol = model_id[len(S3_PREFIX) :] + if "/" not in model_id_no_protocol: + raise ValueError( + f"Invalid model_id {model_id}. " + f"model_id should be of the form `s3://bucket_name/model_id`" + ) + bucket_name, model_id = model_id_no_protocol.split("/", 1) + return bucket_name, model_id + + bucket = os.getenv("PREDIBASE_MODEL_BUCKET") + if not bucket: + # assume that the id preceding the first slash is the bucket name + if "/" not in model_id: + raise ValueError( + f"Invalid model_id {model_id}. " + f"model_id should be of the form `bucket_name/model_id` " + f"if PREDIBASE_MODEL_BUCKET environment variable is not set" + ) + + bucket_name, model_id = model_id.split("/", 1) + return bucket_name, model_id + + return bucket, model_id + -def _get_bucket_resource(): +def _get_bucket_resource(bucket_name: str) -> "Bucket": """Get the s3 client""" config = Config( retries=dict( @@ -27,10 +59,7 @@ def _get_bucket_resource(): ) ) s3 = boto3.resource('s3', config=config) - bucket = os.getenv("PREDIBASE_MODEL_BUCKET") - if not bucket: - raise ValueError("PREDIBASE_MODEL_BUCKET environment variable is not set") - return s3.Bucket(bucket) + return s3.Bucket(bucket_name) def get_s3_model_local_dir(model_id: str): @@ -172,10 +201,11 @@ def __init__(self, model_id: str, revision: Optional[str] = "", extension: str = raise ValueError(f"model_id '{model_id}' is too short for prefix filtering") # TODO: add support for revisions of the same model + bucket, model_id = _get_bucket_and_model_id(model_id) self.model_id = model_id self.revision = revision self.extension = extension - self.bucket = _get_bucket_resource() + self.bucket = _get_bucket_resource(bucket) def remote_weight_files(self, extension: str = None): extension = extension or self.extension diff --git a/server/tests/utils/test_s3.py b/server/tests/utils/test_s3.py new file mode 100644 index 000000000..17476a85e --- /dev/null +++ b/server/tests/utils/test_s3.py @@ -0,0 +1,47 @@ +import contextlib +import os +from typing import Optional + +import pytest + +from lorax_server.utils.sources.s3 import _get_bucket_and_model_id + + +@contextlib.contextmanager +def with_env_var(key: str, value: Optional[str]): + if value is None: + yield + return + + prev = os.environ.get(key) + try: + os.environ[key] = value + yield + finally: + if prev is None: + del os.environ[key] + else: + os.environ[key] = prev + + +@pytest.mark.parametrize( + "s3_path, env_var, expected_bucket, expected_model_id", + [ + ("s3://loras/foobar", None, "loras", "foobar"), + ("s3://loras/foo/bar", None, "loras", "foo/bar"), + ("s3://loras/foo/bar", "bucket", "loras", "foo/bar"), + ("loras/foobar", None, "loras", "foobar"), + ("loras/foo/bar", None, "loras", "foo/bar"), + ("loras/foo/bar", "bucket", "bucket", "loras/foo/bar"), + ] +) +def test_get_bucket_and_model_id( + s3_path: str, + env_var: Optional[str], + expected_bucket: str, + expected_model_id: str, +): + with with_env_var("PREDIBASE_MODEL_BUCKET", env_var): + bucket, model_id = _get_bucket_and_model_id(s3_path) + assert bucket == expected_bucket + assert model_id == expected_model_id